diff --git a/src/Api/Controllers/CollectionsController.cs b/src/Api/Controllers/CollectionsController.cs index b8d83da706..419ee8d816 100644 --- a/src/Api/Controllers/CollectionsController.cs +++ b/src/Api/Controllers/CollectionsController.cs @@ -20,23 +20,26 @@ public class CollectionsController : Controller private readonly ICollectionService _collectionService; private readonly IDeleteCollectionCommand _deleteCollectionCommand; private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; private readonly IAuthorizationService _authorizationService; + private readonly ICurrentContext _currentContext; + private readonly IBulkAddCollectionAccessCommand _bulkAddCollectionAccessCommand; public CollectionsController( ICollectionRepository collectionRepository, ICollectionService collectionService, IDeleteCollectionCommand deleteCollectionCommand, IUserService userService, + IAuthorizationService authorizationService, ICurrentContext currentContext, - IAuthorizationService authorizationService) + IBulkAddCollectionAccessCommand bulkAddCollectionAccessCommand) { _collectionRepository = collectionRepository; _collectionService = collectionService; _deleteCollectionCommand = deleteCollectionCommand; _userService = userService; - _currentContext = currentContext; _authorizationService = authorizationService; + _currentContext = currentContext; + _bulkAddCollectionAccessCommand = bulkAddCollectionAccessCommand; } [HttpGet("{id}")] @@ -190,6 +193,29 @@ public class CollectionsController : Controller await _collectionRepository.UpdateUsersAsync(collection.Id, model?.Select(g => g.ToSelectionReadOnly())); } + [HttpPost("bulk-access")] + public async Task PostBulkCollectionAccess([FromBody] BulkCollectionAccessRequestModel model) + { + var collections = await _collectionRepository.GetManyByManyIdsAsync(model.CollectionIds); + + if (collections.Count != model.CollectionIds.Count()) + { + throw new NotFoundException("One or more collections not found."); + } + + var result = await _authorizationService.AuthorizeAsync(User, collections, CollectionOperations.ModifyAccess); + + if (!result.Succeeded) + { + throw new NotFoundException(); + } + + await _bulkAddCollectionAccessCommand.AddAccessAsync( + collections, + model.Users?.Select(u => u.ToSelectionReadOnly()).ToList(), + model.Groups?.Select(g => g.ToSelectionReadOnly()).ToList()); + } + [HttpDelete("{id}")] [HttpPost("{id}/delete")] public async Task Delete(Guid orgId, Guid id) diff --git a/src/Api/Models/Request/BulkCollectionAccessRequestModel.cs b/src/Api/Models/Request/BulkCollectionAccessRequestModel.cs new file mode 100644 index 0000000000..8076d8ea5a --- /dev/null +++ b/src/Api/Models/Request/BulkCollectionAccessRequestModel.cs @@ -0,0 +1,8 @@ +namespace Bit.Api.Models.Request; + +public class BulkCollectionAccessRequestModel +{ + public IEnumerable CollectionIds { get; set; } + public IEnumerable Groups { get; set; } + public IEnumerable Users { get; set; } +} diff --git a/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs b/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs index 4f4b80a4b4..503883fc57 100644 --- a/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs +++ b/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs @@ -6,7 +6,7 @@ namespace Bit.Api.Models.Request; public class SelectionReadOnlyRequestModel { [Required] - public string Id { get; set; } + public Guid Id { get; set; } public bool ReadOnly { get; set; } public bool HidePasswords { get; set; } public bool Manage { get; set; } @@ -15,7 +15,7 @@ public class SelectionReadOnlyRequestModel { return new CollectionAccessSelection { - Id = new Guid(Id), + Id = Id, ReadOnly = ReadOnly, HidePasswords = HidePasswords, Manage = Manage, diff --git a/src/Core/OrganizationFeatures/OrganizationCollections/BulkAddCollectionAccessCommand.cs b/src/Core/OrganizationFeatures/OrganizationCollections/BulkAddCollectionAccessCommand.cs new file mode 100644 index 0000000000..26f6682ed6 --- /dev/null +++ b/src/Core/OrganizationFeatures/OrganizationCollections/BulkAddCollectionAccessCommand.cs @@ -0,0 +1,95 @@ +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Models.Data; +using Bit.Core.OrganizationFeatures.OrganizationCollections.Interfaces; +using Bit.Core.Repositories; +using Bit.Core.Services; + +namespace Bit.Core.OrganizationFeatures.OrganizationCollections; + +public class BulkAddCollectionAccessCommand : IBulkAddCollectionAccessCommand +{ + private readonly ICollectionRepository _collectionRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IGroupRepository _groupRepository; + private readonly IEventService _eventService; + + public BulkAddCollectionAccessCommand( + ICollectionRepository collectionRepository, + IOrganizationUserRepository organizationUserRepository, + IGroupRepository groupRepository, + IEventService eventService) + { + _collectionRepository = collectionRepository; + _organizationUserRepository = organizationUserRepository; + _groupRepository = groupRepository; + _eventService = eventService; + } + + public async Task AddAccessAsync(ICollection collections, + ICollection users, + ICollection groups) + { + await ValidateRequestAsync(collections, users, groups); + + await _collectionRepository.CreateOrUpdateAccessForManyAsync( + collections.First().OrganizationId, + collections.Select(c => c.Id), + users, + groups + ); + + await _eventService.LogCollectionEventsAsync(collections.Select(c => + (c, EventType.Collection_Updated, (DateTime?)DateTime.UtcNow))); + } + + private async Task ValidateRequestAsync(ICollection collections, ICollection usersAccess, ICollection groupsAccess) + { + if (collections == null || collections.Count == 0) + { + throw new BadRequestException("No collections were provided."); + } + + var orgId = collections.First().OrganizationId; + + if (collections.Any(c => c.OrganizationId != orgId)) + { + throw new BadRequestException("All collections must belong to the same organization."); + } + + var collectionUserIds = usersAccess?.Select(u => u.Id).Distinct().ToList(); + + if (collectionUserIds is { Count: > 0 }) + { + var users = await _organizationUserRepository.GetManyAsync(collectionUserIds); + + if (users.Count != collectionUserIds.Count) + { + throw new BadRequestException("One or more users do not exist."); + } + + if (users.Any(u => u.OrganizationId != orgId)) + { + throw new BadRequestException("One or more users do not belong to the same organization as the collection being assigned."); + } + } + + var collectionGroupIds = groupsAccess?.Select(g => g.Id).Distinct().ToList(); + + if (collectionGroupIds is { Count: > 0 }) + { + var groups = await _groupRepository.GetManyByManyIds(collectionGroupIds); + + if (groups.Count != collectionGroupIds.Count) + { + throw new BadRequestException("One or more groups do not exist."); + } + + if (groups.Any(g => g.OrganizationId != orgId)) + { + throw new BadRequestException("One or more groups do not belong to the same organization as the collection being assigned."); + } + } + } +} diff --git a/src/Core/OrganizationFeatures/OrganizationCollections/Interfaces/IBulkAddCollectionAccessCommand.cs b/src/Core/OrganizationFeatures/OrganizationCollections/Interfaces/IBulkAddCollectionAccessCommand.cs new file mode 100644 index 0000000000..c8ac6d64d3 --- /dev/null +++ b/src/Core/OrganizationFeatures/OrganizationCollections/Interfaces/IBulkAddCollectionAccessCommand.cs @@ -0,0 +1,10 @@ +using Bit.Core.Entities; +using Bit.Core.Models.Data; + +namespace Bit.Core.OrganizationFeatures.OrganizationCollections.Interfaces; + +public interface IBulkAddCollectionAccessCommand +{ + Task AddAccessAsync(ICollection collections, + ICollection users, ICollection groups); +} diff --git a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs index 1756149236..7add19cfd4 100644 --- a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs +++ b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs @@ -98,6 +98,7 @@ public static class OrganizationServiceCollectionExtensions public static void AddOrganizationCollectionCommands(this IServiceCollection services) { services.AddScoped(); + services.AddScoped(); } private static void AddOrganizationGroupCommands(this IServiceCollection services) diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index 2114a60a89..8b5cc8d7ec 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -20,4 +20,6 @@ public interface ICollectionRepository : IRepository Task UpdateUsersAsync(Guid id, IEnumerable users); Task> GetManyUsersByIdAsync(Guid id); Task DeleteManyAsync(IEnumerable collectionIds); + Task CreateOrUpdateAccessForManyAsync(Guid organizationId, IEnumerable collectionIds, + IEnumerable users, IEnumerable groups); } diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index 9e6f6c0bcf..37949da464 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -252,6 +252,21 @@ public class CollectionRepository : Repository, ICollectionRep } } + public async Task CreateOrUpdateAccessForManyAsync(Guid organizationId, IEnumerable collectionIds, + IEnumerable users, IEnumerable groups) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var usersArray = users != null ? users.ToArrayTVP() : Enumerable.Empty().ToArrayTVP(); + var groupsArray = groups != null ? groups.ToArrayTVP() : Enumerable.Empty().ToArrayTVP(); + + var results = await connection.ExecuteAsync( + $"[{Schema}].[Collection_CreateOrUpdateAccessForMany]", + new { OrganizationId = organizationId, CollectionIds = collectionIds.ToGuidIdArrayTVP(), Users = usersArray, Groups = groupsArray }, + commandType: CommandType.StoredProcedure); + } + } + public async Task CreateUserAsync(Guid collectionId, Guid organizationUserId) { using (var connection = new SqlConnection(ConnectionString)) diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index f7e616459b..e1e7789214 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -473,6 +473,97 @@ public class CollectionRepository : Repository collectionIds, + IEnumerable users, IEnumerable groups) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + var collectionIdsList = collectionIds.ToList(); + + if (users != null) + { + var existingCollectionUsers = await dbContext.CollectionUsers + .Where(cu => collectionIdsList.Contains(cu.CollectionId)) + .ToDictionaryAsync(x => (x.CollectionId, x.OrganizationUserId)); + + var requestedUsers = users.ToList(); + + foreach (var collectionId in collectionIdsList) + { + foreach (var requestedUser in requestedUsers) + { + if (!existingCollectionUsers.TryGetValue( + (collectionId, requestedUser.Id), + out var existingCollectionUser) + ) + { + // This is a brand new entry + dbContext.CollectionUsers.Add(new CollectionUser + { + CollectionId = collectionId, + OrganizationUserId = requestedUser.Id, + HidePasswords = requestedUser.HidePasswords, + ReadOnly = requestedUser.ReadOnly, + Manage = requestedUser.Manage + }); + continue; + } + + // It already exists, update it + existingCollectionUser.HidePasswords = requestedUser.HidePasswords; + existingCollectionUser.ReadOnly = requestedUser.ReadOnly; + existingCollectionUser.Manage = requestedUser.Manage; + dbContext.CollectionUsers.Update(existingCollectionUser); + } + } + } + + if (groups != null) + { + var existingCollectionGroups = await dbContext.CollectionGroups + .Where(cu => collectionIdsList.Contains(cu.CollectionId)) + .ToDictionaryAsync(x => (x.CollectionId, x.GroupId)); + + var requestedGroups = groups.ToList(); + + foreach (var collectionId in collectionIdsList) + { + foreach (var requestedGroup in requestedGroups) + { + if (!existingCollectionGroups.TryGetValue( + (collectionId, requestedGroup.Id), + out var existingCollectionGroup) + ) + { + // This is a brand new entry + dbContext.CollectionGroups.Add(new CollectionGroup() + { + CollectionId = collectionId, + GroupId = requestedGroup.Id, + HidePasswords = requestedGroup.HidePasswords, + ReadOnly = requestedGroup.ReadOnly, + Manage = requestedGroup.Manage + }); + continue; + } + + // It already exists, update it + existingCollectionGroup.HidePasswords = requestedGroup.HidePasswords; + existingCollectionGroup.ReadOnly = requestedGroup.ReadOnly; + existingCollectionGroup.Manage = requestedGroup.Manage; + dbContext.CollectionGroups.Update(existingCollectionGroup); + } + } + } + // Need to save the new collection users/groups before running the bump revision code + await dbContext.SaveChangesAsync(); + await dbContext.UserBumpAccountRevisionDateByCollectionIdsAsync(collectionIdsList, organizationId); + await dbContext.SaveChangesAsync(); + } + } + private async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable groups) { var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId); diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContextExtensions.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContextExtensions.cs index 480855d61c..94d8bb52d8 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContextExtensions.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContextExtensions.cs @@ -74,6 +74,39 @@ public static class DatabaseContextExtensions UpdateUserRevisionDate(users); } + public static async Task UserBumpAccountRevisionDateByCollectionIdsAsync(this DatabaseContext context, IEnumerable collectionIds, Guid organizationId) + { + var query = from u in context.Users + from c in context.Collections + join ou in context.OrganizationUsers + on u.Id equals ou.UserId + join cu in context.CollectionUsers + on new { ou.AccessAll, OrganizationUserId = ou.Id, CollectionId = c.Id } equals + new { AccessAll = false, cu.OrganizationUserId, cu.CollectionId } into cu_g + from cu in cu_g.DefaultIfEmpty() + join gu in context.GroupUsers + on new { CollectionId = (Guid?)cu.CollectionId, ou.AccessAll, OrganizationUserId = ou.Id } equals + new { CollectionId = (Guid?)null, AccessAll = false, gu.OrganizationUserId } into gu_g + from gu in gu_g.DefaultIfEmpty() + join g in context.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in context.CollectionGroups + on new { g.AccessAll, gu.GroupId, CollectionId = c.Id } equals + new { AccessAll = false, cg.GroupId, cg.CollectionId } into cg_g + from cg in cg_g.DefaultIfEmpty() + where ou.OrganizationId == organizationId && collectionIds.Contains(c.Id) && + ou.Status == OrganizationUserStatusType.Confirmed && + (cu.CollectionId != null || + cg.CollectionId != null || + ou.AccessAll == true || + g.AccessAll == true) + select u; + + var users = await query.ToListAsync(); + UpdateUserRevisionDate(users); + } + public static async Task UserBumpAccountRevisionDateByOrganizationUserIdAsync(this DatabaseContext context, Guid organizationUserId) { var query = from u in context.Users diff --git a/src/Sql/dbo/Stored Procedures/Collection_CreateOrUpdateAccessForMany.sql b/src/Sql/dbo/Stored Procedures/Collection_CreateOrUpdateAccessForMany.sql new file mode 100644 index 0000000000..e7f860fa60 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/Collection_CreateOrUpdateAccessForMany.sql @@ -0,0 +1,97 @@ +CREATE PROCEDURE [dbo].[Collection_CreateOrUpdateAccessForMany] + @OrganizationId UNIQUEIDENTIFIER, + @CollectionIds AS [dbo].[GuidIdArray] READONLY, + @Groups AS [dbo].[SelectionReadOnlyArray] READONLY, + @Users AS [dbo].[SelectionReadOnlyArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + -- Groups + ;WITH [NewCollectionGroups] AS ( + SELECT + cId.[Id] AS [CollectionId], + cg.[Id] AS [GroupId], + cg.[ReadOnly], + cg.[HidePasswords], + cg.[Manage] + FROM + @Groups AS cg + CROSS JOIN -- Create a CollectionGroup record for every CollectionId + @CollectionIds cId + INNER JOIN + [dbo].[Group] g ON cg.[Id] = g.[Id] + WHERE + g.[OrganizationId] = @OrganizationId + ) + MERGE + [dbo].[CollectionGroup] as [Target] + USING + [NewCollectionGroups] AS [Source] + ON + [Target].[CollectionId] = [Source].[CollectionId] + AND [Target].[GroupId] = [Source].[GroupId] + -- Update the target if any values are different from the source + WHEN MATCHED AND EXISTS( + SELECT [Source].[ReadOnly], [Source].[HidePasswords], [Source].[Manage] + EXCEPT + SELECT [Target].[ReadOnly], [Target].[HidePasswords], [Target].[Manage] + ) THEN UPDATE SET + [Target].[ReadOnly] = [Source].[ReadOnly], + [Target].[HidePasswords] = [Source].[HidePasswords], + [Target].[Manage] = [Source].[Manage] + WHEN NOT MATCHED BY TARGET + THEN INSERT VALUES + ( + [Source].[CollectionId], + [Source].[GroupId], + [Source].[ReadOnly], + [Source].[HidePasswords], + [Source].[Manage] + ); + + -- Users + ;WITH [NewCollectionUsers] AS ( + SELECT + cId.[Id] AS [CollectionId], + cu.[Id] AS [OrganizationUserId], + cu.[ReadOnly], + cu.[HidePasswords], + cu.[Manage] + FROM + @Users AS cu + CROSS JOIN -- Create a CollectionUser record for every CollectionId + @CollectionIds cId + INNER JOIN + [dbo].[OrganizationUser] u ON cu.[Id] = u.[Id] + WHERE + u.[OrganizationId] = @OrganizationId + ) + MERGE + [dbo].[CollectionUser] as [Target] + USING + [NewCollectionUsers] AS [Source] + ON + [Target].[CollectionId] = [Source].[CollectionId] + AND [Target].[OrganizationUserId] = [Source].[OrganizationUserId] + -- Update the target if any values are different from the source + WHEN MATCHED AND EXISTS( + SELECT [Source].[ReadOnly], [Source].[HidePasswords], [Source].[Manage] + EXCEPT + SELECT [Target].[ReadOnly], [Target].[HidePasswords], [Target].[Manage] + ) THEN UPDATE SET + [Target].[ReadOnly] = [Source].[ReadOnly], + [Target].[HidePasswords] = [Source].[HidePasswords], + [Target].[Manage] = [Source].[Manage] + WHEN NOT MATCHED BY TARGET + THEN INSERT VALUES + ( + [Source].[CollectionId], + [Source].[OrganizationUserId], + [Source].[ReadOnly], + [Source].[HidePasswords], + [Source].[Manage] + ); + + EXEC [dbo].[User_BumpAccountRevisionDateByCollectionIds] @CollectionIds, @OrganizationId +END diff --git a/src/Sql/dbo/Stored Procedures/User_BumpAccountRevisionDateByCollectionIds.sql b/src/Sql/dbo/Stored Procedures/User_BumpAccountRevisionDateByCollectionIds.sql new file mode 100644 index 0000000000..d027708a63 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/User_BumpAccountRevisionDateByCollectionIds.sql @@ -0,0 +1,35 @@ +CREATE PROCEDURE [dbo].[User_BumpAccountRevisionDateByCollectionIds] + @CollectionIds AS [dbo].[GuidIdArray] READONLY, + @OrganizationId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON + +UPDATE + U +SET + U.[AccountRevisionDate] = GETUTCDATE() + FROM + [dbo].[User] U + INNER JOIN + [dbo].[Collection] C ON C.[Id] IN (SELECT [Id] FROM @CollectionIds) + INNER JOIN + [dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id] + LEFT JOIN + [dbo].[CollectionUser] CU ON OU.[AccessAll] = 0 AND CU.[OrganizationUserId] = OU.[Id] AND CU.[CollectionId] = C.[Id] + LEFT JOIN + [dbo].[GroupUser] GU ON CU.[CollectionId] IS NULL AND OU.[AccessAll] = 0 AND GU.[OrganizationUserId] = OU.[Id] + LEFT JOIN + [dbo].[Group] G ON G.[Id] = GU.[GroupId] + LEFT JOIN + [dbo].[CollectionGroup] CG ON G.[AccessAll] = 0 AND CG.[GroupId] = GU.[GroupId] AND CG.[CollectionId] = C.[Id] +WHERE + OU.[OrganizationId] = @OrganizationId + AND OU.[Status] = 2 -- 2 = Confirmed + AND ( + CU.[CollectionId] IS NOT NULL + OR CG.[CollectionId] IS NOT NULL + OR OU.[AccessAll] = 1 + OR G.[AccessAll] = 1 + ) +END diff --git a/test/Api.Test/Controllers/CollectionsControllerTests.cs b/test/Api.Test/Controllers/CollectionsControllerTests.cs index 564ab4e15b..d4e5aeac16 100644 --- a/test/Api.Test/Controllers/CollectionsControllerTests.cs +++ b/test/Api.Test/Controllers/CollectionsControllerTests.cs @@ -221,4 +221,120 @@ public class CollectionsControllerTests .DidNotReceiveWithAnyArgs() .DeleteManyAsync((IEnumerable)default); } + + [Theory, BitAutoData] + public async Task PostBulkCollectionAccess_Success(User actingUser, ICollection collections, SutProvider sutProvider) + { + // Arrange + var userId = Guid.NewGuid(); + var groupId = Guid.NewGuid(); + var model = new BulkCollectionAccessRequestModel + { + CollectionIds = collections.Select(c => c.Id), + Users = new[] { new SelectionReadOnlyRequestModel { Id = userId, Manage = true } }, + Groups = new[] { new SelectionReadOnlyRequestModel { Id = groupId, ReadOnly = true } }, + }; + + sutProvider.GetDependency() + .GetManyByManyIdsAsync(model.CollectionIds) + .Returns(collections); + + sutProvider.GetDependency() + .UserId.Returns(actingUser.Id); + + sutProvider.GetDependency().AuthorizeAsync( + Arg.Any(), ExpectedCollectionAccess(), + Arg.Is>( + r => r.Contains(CollectionOperations.ModifyAccess) + )) + .Returns(AuthorizationResult.Success()); + + IEnumerable ExpectedCollectionAccess() => Arg.Is>(cols => cols.SequenceEqual(collections)); + + // Act + await sutProvider.Sut.PostBulkCollectionAccess(model); + + // Assert + await sutProvider.GetDependency().Received().AuthorizeAsync( + Arg.Any(), + ExpectedCollectionAccess(), + Arg.Is>( + r => r.Contains(CollectionOperations.ModifyAccess)) + ); + await sutProvider.GetDependency().Received() + .AddAccessAsync( + Arg.Is>(g => g.SequenceEqual(collections)), + Arg.Is>(u => u.All(c => c.Id == userId && c.Manage)), + Arg.Is>(g => g.All(c => c.Id == groupId && c.ReadOnly))); + } + + [Theory, BitAutoData] + public async Task PostBulkCollectionAccess_CollectionsNotFound_Throws(User actingUser, ICollection collections, SutProvider sutProvider) + { + var userId = Guid.NewGuid(); + var groupId = Guid.NewGuid(); + var model = new BulkCollectionAccessRequestModel + { + CollectionIds = collections.Select(c => c.Id), + Users = new[] { new SelectionReadOnlyRequestModel { Id = userId, Manage = true } }, + Groups = new[] { new SelectionReadOnlyRequestModel { Id = groupId, ReadOnly = true } }, + }; + + sutProvider.GetDependency() + .UserId.Returns(actingUser.Id); + + sutProvider.GetDependency() + .GetManyByManyIdsAsync(model.CollectionIds) + .Returns(collections.Skip(1).ToList()); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.PostBulkCollectionAccess(model)); + + Assert.Equal("One or more collections not found.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().AuthorizeAsync( + Arg.Any(), + Arg.Any>(), + Arg.Any>() + ); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .AddAccessAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task PostBulkCollectionAccess_AccessDenied_Throws(User actingUser, ICollection collections, SutProvider sutProvider) + { + var userId = Guid.NewGuid(); + var groupId = Guid.NewGuid(); + var model = new BulkCollectionAccessRequestModel + { + CollectionIds = collections.Select(c => c.Id), + Users = new[] { new SelectionReadOnlyRequestModel { Id = userId, Manage = true } }, + Groups = new[] { new SelectionReadOnlyRequestModel { Id = groupId, ReadOnly = true } }, + }; + + sutProvider.GetDependency() + .UserId.Returns(actingUser.Id); + + sutProvider.GetDependency() + .GetManyByManyIdsAsync(model.CollectionIds) + .Returns(collections); + + sutProvider.GetDependency().AuthorizeAsync( + Arg.Any(), ExpectedCollectionAccess(), + Arg.Is>( + r => r.Contains(CollectionOperations.ModifyAccess) + )) + .Returns(AuthorizationResult.Failed()); + + IEnumerable ExpectedCollectionAccess() => Arg.Is>(cols => cols.SequenceEqual(collections)); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PostBulkCollectionAccess(model)); + await sutProvider.GetDependency().Received().AuthorizeAsync( + Arg.Any(), + ExpectedCollectionAccess(), + Arg.Is>( + r => r.Contains(CollectionOperations.ModifyAccess)) + ); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .AddAccessAsync(default, default, default); + } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationCollections/BulkAddCollectionAccessCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationCollections/BulkAddCollectionAccessCommandTests.cs new file mode 100644 index 0000000000..349a1bd690 --- /dev/null +++ b/test/Core.Test/OrganizationFeatures/OrganizationCollections/BulkAddCollectionAccessCommandTests.cs @@ -0,0 +1,271 @@ +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Models.Data; +using Bit.Core.OrganizationFeatures.OrganizationCollections; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.Vault.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.OrganizationFeatures.OrganizationCollections; + +[SutProviderCustomize] +public class BulkAddCollectionAccessCommandTests +{ + [Theory, BitAutoData, CollectionCustomization] + public async Task AddAccessAsync_Success(SutProvider sutProvider, + Organization org, + ICollection collections, + ICollection organizationUsers, + ICollection groups, + IEnumerable collectionUsers, + IEnumerable collectionGroups) + { + sutProvider.GetDependency() + .GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(collectionUsers.Select(u => u.OrganizationUserId))) + ) + .Returns(organizationUsers); + + sutProvider.GetDependency() + .GetManyByManyIds( + Arg.Is>(ids => ids.SequenceEqual(collectionGroups.Select(u => u.GroupId))) + ) + .Returns(groups); + + var userAccessSelections = ToAccessSelection(collectionUsers); + var groupAccessSelections = ToAccessSelection(collectionGroups); + await sutProvider.Sut.AddAccessAsync(collections, + userAccessSelections, + groupAccessSelections + ); + + await sutProvider.GetDependency().Received().GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(userAccessSelections.Select(u => u.Id))) + ); + await sutProvider.GetDependency().Received().GetManyByManyIds( + Arg.Is>(ids => ids.SequenceEqual(groupAccessSelections.Select(g => g.Id))) + ); + + await sutProvider.GetDependency().Received().CreateOrUpdateAccessForManyAsync( + org.Id, + Arg.Is>(ids => ids.SequenceEqual(collections.Select(c => c.Id))), + userAccessSelections, + groupAccessSelections); + + await sutProvider.GetDependency().Received().LogCollectionEventsAsync( + Arg.Is>( + events => events.All(e => + collections.Contains(e.Item1) && + e.Item2 == EventType.Collection_Updated && + e.Item3.HasValue + ) + ) + ); + } + + [Theory, BitAutoData, CollectionCustomization] + public async Task ValidateRequestAsync_NoCollectionsProvided_Failure(SutProvider sutProvider) + { + var exception = + await Assert.ThrowsAsync( + () => sutProvider.Sut.AddAccessAsync(null, null, null)); + + Assert.Contains("No collections were provided.", exception.Message); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyByManyIdsAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyByManyIds(default); + } + + + [Theory, BitAutoData, CollectionCustomization] + public async Task ValidateRequestAsync_NoCollection_Failure(SutProvider sutProvider, + IEnumerable collectionUsers, + IEnumerable collectionGroups) + { + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.AddAccessAsync(Enumerable.Empty().ToList(), + ToAccessSelection(collectionUsers), + ToAccessSelection(collectionGroups) + )); + + Assert.Contains("No collections were provided.", exception.Message); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyByManyIds(default); + } + + [Theory, BitAutoData, CollectionCustomization] + public async Task ValidateRequestAsync_DifferentOrgs_Failure(SutProvider sutProvider, + ICollection collections, + IEnumerable collectionUsers, + IEnumerable collectionGroups) + { + collections.First().OrganizationId = Guid.NewGuid(); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.AddAccessAsync(collections, + ToAccessSelection(collectionUsers), + ToAccessSelection(collectionGroups) + )); + + Assert.Contains("All collections must belong to the same organization.", exception.Message); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyByManyIds(default); + } + + [Theory, BitAutoData, CollectionCustomization] + public async Task ValidateRequestAsync_MissingUser_Failure(SutProvider sutProvider, + IList collections, + IList organizationUsers, + IEnumerable collectionUsers, + IEnumerable collectionGroups) + { + organizationUsers.RemoveAt(0); + + sutProvider.GetDependency() + .GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(collectionUsers.Select(u => u.OrganizationUserId))) + ) + .Returns(organizationUsers); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.AddAccessAsync(collections, + ToAccessSelection(collectionUsers), + ToAccessSelection(collectionGroups) + )); + + Assert.Contains("One or more users do not exist.", exception.Message); + + await sutProvider.GetDependency().Received().GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(collectionUsers.Select(u => u.OrganizationUserId))) + ); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyByManyIds(default); + } + + [Theory, BitAutoData, CollectionCustomization] + public async Task ValidateRequestAsync_UserWrongOrg_Failure(SutProvider sutProvider, + IList collections, + IList organizationUsers, + IEnumerable collectionUsers, + IEnumerable collectionGroups) + { + organizationUsers.First().OrganizationId = Guid.NewGuid(); + + sutProvider.GetDependency() + .GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(collectionUsers.Select(u => u.OrganizationUserId))) + ) + .Returns(organizationUsers); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.AddAccessAsync(collections, + ToAccessSelection(collectionUsers), + ToAccessSelection(collectionGroups) + )); + + Assert.Contains("One or more users do not belong to the same organization as the collection being assigned.", exception.Message); + + await sutProvider.GetDependency().Received().GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(collectionUsers.Select(u => u.OrganizationUserId))) + ); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyByManyIds(default); + } + + [Theory, BitAutoData, CollectionCustomization] + public async Task ValidateRequestAsync_MissingGroup_Failure(SutProvider sutProvider, + IList collections, + IList organizationUsers, + IList groups, + IEnumerable collectionUsers, + IEnumerable collectionGroups) + { + groups.RemoveAt(0); + + sutProvider.GetDependency() + .GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(collectionUsers.Select(u => u.OrganizationUserId))) + ) + .Returns(organizationUsers); + + sutProvider.GetDependency() + .GetManyByManyIds( + Arg.Is>(ids => ids.SequenceEqual(collectionGroups.Select(u => u.GroupId))) + ) + .Returns(groups); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.AddAccessAsync(collections, + ToAccessSelection(collectionUsers), + ToAccessSelection(collectionGroups) + )); + + Assert.Contains("One or more groups do not exist.", exception.Message); + + await sutProvider.GetDependency().Received().GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(collectionUsers.Select(u => u.OrganizationUserId))) + ); + await sutProvider.GetDependency().Received().GetManyByManyIds( + Arg.Is>(ids => ids.SequenceEqual(collectionGroups.Select(u => u.GroupId))) + ); + } + + [Theory, BitAutoData, CollectionCustomization] + public async Task ValidateRequestAsync_GroupWrongOrg_Failure(SutProvider sutProvider, + IList collections, + IList organizationUsers, + IList groups, + IEnumerable collectionUsers, + IEnumerable collectionGroups) + { + groups.First().OrganizationId = Guid.NewGuid(); + + sutProvider.GetDependency() + .GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(collectionUsers.Select(u => u.OrganizationUserId))) + ) + .Returns(organizationUsers); + + sutProvider.GetDependency() + .GetManyByManyIds( + Arg.Is>(ids => ids.SequenceEqual(collectionGroups.Select(u => u.GroupId))) + ) + .Returns(groups); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.AddAccessAsync(collections, + ToAccessSelection(collectionUsers), + ToAccessSelection(collectionGroups) + )); + + Assert.Contains("One or more groups do not belong to the same organization as the collection being assigned.", exception.Message); + + await sutProvider.GetDependency().Received().GetManyAsync( + Arg.Is>(ids => ids.SequenceEqual(collectionUsers.Select(u => u.OrganizationUserId))) + ); + await sutProvider.GetDependency().Received().GetManyByManyIds( + Arg.Is>(ids => ids.SequenceEqual(collectionGroups.Select(u => u.GroupId))) + ); + } + + private static ICollection ToAccessSelection(IEnumerable collectionUsers) + { + return collectionUsers.Select(cu => new CollectionAccessSelection + { + Id = cu.OrganizationUserId, + Manage = cu.Manage, + HidePasswords = cu.HidePasswords, + ReadOnly = cu.ReadOnly + }).ToList(); + } + private static ICollection ToAccessSelection(IEnumerable collectionGroups) + { + return collectionGroups.Select(cg => new CollectionAccessSelection + { + Id = cg.GroupId, + Manage = cg.Manage, + HidePasswords = cg.HidePasswords, + ReadOnly = cg.ReadOnly + }).ToList(); + } +} diff --git a/test/Core.Test/Vault/AutoFixture/CollectionFixture.cs b/test/Core.Test/Vault/AutoFixture/CollectionFixture.cs index e74cdac1c0..3a84d3438f 100644 --- a/test/Core.Test/Vault/AutoFixture/CollectionFixture.cs +++ b/test/Core.Test/Vault/AutoFixture/CollectionFixture.cs @@ -17,6 +17,9 @@ public class CollectionCustomization : ICustomization { var orgId = Guid.NewGuid(); + fixture.Customize(composer => composer + .With(o => o.Id, orgId)); + fixture.Customize(composer => composer .With(o => o.Id, orgId)); diff --git a/util/Migrator/DbScripts/2023-08-25_00_BulkAddCollectionAccess.sql b/util/Migrator/DbScripts/2023-08-25_00_BulkAddCollectionAccess.sql new file mode 100644 index 0000000000..d45d7dc5bd --- /dev/null +++ b/util/Migrator/DbScripts/2023-08-25_00_BulkAddCollectionAccess.sql @@ -0,0 +1,135 @@ +CREATE OR ALTER PROCEDURE [dbo].[User_BumpAccountRevisionDateByCollectionIds] + @CollectionIds AS [dbo].[GuidIdArray] READONLY, + @OrganizationId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON + +UPDATE + U +SET + U.[AccountRevisionDate] = GETUTCDATE() + FROM + [dbo].[User] U + INNER JOIN + [dbo].[Collection] C ON C.[Id] IN (SELECT [Id] FROM @CollectionIds) + INNER JOIN + [dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id] + LEFT JOIN + [dbo].[CollectionUser] CU ON OU.[AccessAll] = 0 AND CU.[OrganizationUserId] = OU.[Id] AND CU.[CollectionId] = C.[Id] + LEFT JOIN + [dbo].[GroupUser] GU ON CU.[CollectionId] IS NULL AND OU.[AccessAll] = 0 AND GU.[OrganizationUserId] = OU.[Id] + LEFT JOIN + [dbo].[Group] G ON G.[Id] = GU.[GroupId] + LEFT JOIN + [dbo].[CollectionGroup] CG ON G.[AccessAll] = 0 AND CG.[GroupId] = GU.[GroupId] AND CG.[CollectionId] = C.[Id] +WHERE + OU.[OrganizationId] = @OrganizationId + AND OU.[Status] = 2 -- 2 = Confirmed + AND ( + CU.[CollectionId] IS NOT NULL + OR CG.[CollectionId] IS NOT NULL + OR OU.[AccessAll] = 1 + OR G.[AccessAll] = 1 + ) +END +GO + +CREATE OR ALTER PROCEDURE [dbo].[Collection_CreateOrUpdateAccessForMany] + @OrganizationId UNIQUEIDENTIFIER, + @CollectionIds AS [dbo].[GuidIdArray] READONLY, + @Groups AS [dbo].[SelectionReadOnlyArray] READONLY, + @Users AS [dbo].[SelectionReadOnlyArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + -- Groups + ;WITH [NewCollectionGroups] AS ( + SELECT + cId.[Id] AS [CollectionId], + cg.[Id] AS [GroupId], + cg.[ReadOnly], + cg.[HidePasswords], + cg.[Manage] + FROM + @Groups AS cg + CROSS JOIN -- Create a CollectionGroup record for every CollectionId + @CollectionIds cId + INNER JOIN + [dbo].[Group] g ON cg.[Id] = g.[Id] + WHERE + g.[OrganizationId] = @OrganizationId + ) + MERGE + [dbo].[CollectionGroup] as [Target] + USING + [NewCollectionGroups] AS [Source] + ON + [Target].[CollectionId] = [Source].[CollectionId] + AND [Target].[GroupId] = [Source].[GroupId] + -- Update the target if any values are different from the source + WHEN MATCHED AND EXISTS( + SELECT [Source].[ReadOnly], [Source].[HidePasswords], [Source].[Manage] + EXCEPT + SELECT [Target].[ReadOnly], [Target].[HidePasswords], [Target].[Manage] + ) THEN UPDATE SET + [Target].[ReadOnly] = [Source].[ReadOnly], + [Target].[HidePasswords] = [Source].[HidePasswords], + [Target].[Manage] = [Source].[Manage] + WHEN NOT MATCHED BY TARGET + THEN INSERT VALUES + ( + [Source].[CollectionId], + [Source].[GroupId], + [Source].[ReadOnly], + [Source].[HidePasswords], + [Source].[Manage] + ); + + -- Users + ;WITH [NewCollectionUsers] AS ( + SELECT + cId.[Id] AS [CollectionId], + cu.[Id] AS [OrganizationUserId], + cu.[ReadOnly], + cu.[HidePasswords], + cu.[Manage] + FROM + @Users AS cu + CROSS JOIN -- Create a CollectionUser record for every CollectionId + @CollectionIds cId + INNER JOIN + [dbo].[OrganizationUser] u ON cu.[Id] = u.[Id] + WHERE + u.[OrganizationId] = @OrganizationId + ) + MERGE + [dbo].[CollectionUser] as [Target] + USING + [NewCollectionUsers] AS [Source] + ON + [Target].[CollectionId] = [Source].[CollectionId] + AND [Target].[OrganizationUserId] = [Source].[OrganizationUserId] + -- Update the target if any values are different from the source + WHEN MATCHED AND EXISTS( + SELECT [Source].[ReadOnly], [Source].[HidePasswords], [Source].[Manage] + EXCEPT + SELECT [Target].[ReadOnly], [Target].[HidePasswords], [Target].[Manage] + ) THEN UPDATE SET + [Target].[ReadOnly] = [Source].[ReadOnly], + [Target].[HidePasswords] = [Source].[HidePasswords], + [Target].[Manage] = [Source].[Manage] + WHEN NOT MATCHED BY TARGET + THEN INSERT VALUES + ( + [Source].[CollectionId], + [Source].[OrganizationUserId], + [Source].[ReadOnly], + [Source].[HidePasswords], + [Source].[Manage] + ); + + EXEC [dbo].[User_BumpAccountRevisionDateByCollectionIds] @CollectionIds, @OrganizationId +END +GO