diff --git a/src/Api/Controllers/CollectionsController.cs b/src/Api/Controllers/CollectionsController.cs index 419ee8d816..e4010d0018 100644 --- a/src/Api/Controllers/CollectionsController.cs +++ b/src/Api/Controllers/CollectionsController.cs @@ -161,7 +161,7 @@ public class CollectionsController : Controller var groups = model.Groups?.Select(g => g.ToSelectionReadOnly()); var users = model.Users?.Select(g => g.ToSelectionReadOnly()); - await _collectionService.SaveAsync(collection, groups, users, _currentContext.UserId); + await _collectionService.SaveAsync(collection, groups, users); return new CollectionResponseModel(collection); } diff --git a/src/Core/Services/ICollectionService.cs b/src/Core/Services/ICollectionService.cs index 931993dacb..4d392a7722 100644 --- a/src/Core/Services/ICollectionService.cs +++ b/src/Core/Services/ICollectionService.cs @@ -5,7 +5,7 @@ namespace Bit.Core.Services; public interface ICollectionService { - Task SaveAsync(Collection collection, IEnumerable groups = null, IEnumerable users = null, Guid? assignUserId = null); + Task SaveAsync(Collection collection, IEnumerable groups = null, IEnumerable users = null); Task DeleteUserAsync(Collection collection, Guid organizationUserId); Task> GetOrganizationCollectionsAsync(Guid organizationId); } diff --git a/src/Core/Services/Implementations/CollectionService.cs b/src/Core/Services/Implementations/CollectionService.cs index 6525fdc210..b2beccbbce 100644 --- a/src/Core/Services/Implementations/CollectionService.cs +++ b/src/Core/Services/Implementations/CollectionService.cs @@ -41,7 +41,7 @@ public class CollectionService : ICollectionService } public async Task SaveAsync(Collection collection, IEnumerable groups = null, - IEnumerable users = null, Guid? assignUserId = null) + IEnumerable users = null) { var org = await _organizationRepository.GetByIdAsync(collection.OrganizationId); if (org == null) @@ -49,6 +49,16 @@ public class CollectionService : ICollectionService throw new BadRequestException("Organization not found"); } + var groupsList = groups?.ToList(); + var usersList = users?.ToList(); + var groupHasManageAccess = groupsList?.Any(g => g.Manage) ?? false; + var userHasManageAccess = usersList?.Any(u => u.Manage) ?? false; + if (!groupHasManageAccess && !userHasManageAccess) + { + throw new BadRequestException( + "At least one member or group must have can manage permission."); + } + if (collection.Id == default(Guid)) { if (org.MaxCollections.HasValue) @@ -61,26 +71,13 @@ public class CollectionService : ICollectionService } } - await _collectionRepository.CreateAsync(collection, org.UseGroups ? groups : null, users); - - // Assign a user to the newly created collection. - if (assignUserId.HasValue) - { - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, assignUserId.Value); - if (orgUser != null && orgUser.Status == Enums.OrganizationUserStatusType.Confirmed) - { - await _collectionRepository.UpdateUsersAsync(collection.Id, - new List { - new CollectionAccessSelection { Id = orgUser.Id, Manage = true} }); - } - } - + await _collectionRepository.CreateAsync(collection, org.UseGroups ? groupsList : null, usersList); await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Created); await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.CollectionCreated, org, _currentContext)); } else { - await _collectionRepository.ReplaceAsync(collection, org.UseGroups ? groups : null, users); + await _collectionRepository.ReplaceAsync(collection, org.UseGroups ? groupsList : null, usersList); await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Updated); } } diff --git a/test/Api.Test/Controllers/CollectionsControllerTests.cs b/test/Api.Test/Controllers/CollectionsControllerTests.cs index d4e5aeac16..3bfaa8b02c 100644 --- a/test/Api.Test/Controllers/CollectionsControllerTests.cs +++ b/test/Api.Test/Controllers/CollectionsControllerTests.cs @@ -40,7 +40,7 @@ public class CollectionsControllerTests await sutProvider.GetDependency() .Received(1) .SaveAsync(Arg.Any(), Arg.Any>(), - Arg.Any>(), null); + Arg.Any>()); } [Theory, BitAutoData] diff --git a/test/Core.Test/AutoFixture/CollectionAccessSelectionFixtures.cs b/test/Core.Test/AutoFixture/CollectionAccessSelectionFixtures.cs new file mode 100644 index 0000000000..54b7fb034f --- /dev/null +++ b/test/Core.Test/AutoFixture/CollectionAccessSelectionFixtures.cs @@ -0,0 +1,37 @@ +using System.Reflection; +using AutoFixture; +using AutoFixture.Xunit2; +using Bit.Core.Models.Data; + +namespace Bit.Core.Test.AutoFixture; + +public class CollectionAccessSelectionCustomization : ICustomization +{ + public bool Manage { get; set; } + + public CollectionAccessSelectionCustomization(bool manage) + { + Manage = manage; + } + + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.Manage, Manage)); + } +} + +public class CollectionAccessSelectionCustomizeAttribute : CustomizeAttribute +{ + private readonly bool _manage; + + public CollectionAccessSelectionCustomizeAttribute(bool manage = false) + { + _manage = manage; + } + + public override ICustomization GetCustomization(ParameterInfo parameter) + { + return new CollectionAccessSelectionCustomization(_manage); + } +} diff --git a/test/Core.Test/Services/CollectionServiceTests.cs b/test/Core.Test/Services/CollectionServiceTests.cs index d5b5f15ccd..0ce0a90dc4 100644 --- a/test/Core.Test/Services/CollectionServiceTests.cs +++ b/test/Core.Test/Services/CollectionServiceTests.cs @@ -5,6 +5,7 @@ using Bit.Core.Exceptions; using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.AutoFixture; using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -18,23 +19,7 @@ namespace Bit.Core.Test.Services; public class CollectionServiceTest { [Theory, BitAutoData] - public async Task SaveAsync_DefaultId_CreatesCollectionInTheRepository(Collection collection, Organization organization, SutProvider sutProvider) - { - collection.Id = default; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var utcNow = DateTime.UtcNow; - - await sutProvider.Sut.SaveAsync(collection); - - await sutProvider.GetDependency().Received().CreateAsync(collection, null, null); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, BitAutoData] - public async Task SaveAsync_DefaultIdWithUsers_CreatesCollectionInTheRepository(Collection collection, Organization organization, IEnumerable users, SutProvider sutProvider) + public async Task SaveAsync_DefaultIdWithUsers_CreatesCollectionInTheRepository(Collection collection, Organization organization, [CollectionAccessSelectionCustomize(true)] IEnumerable users, SutProvider sutProvider) { collection.Id = default; sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); @@ -42,7 +27,9 @@ public class CollectionServiceTest await sutProvider.Sut.SaveAsync(collection, null, users); - await sutProvider.GetDependency().Received().CreateAsync(collection, null, users); + await sutProvider.GetDependency().Received() + .CreateAsync(collection, Arg.Is>(l => l == null), + Arg.Is>(l => l.Any(i => i.Manage == true))); await sutProvider.GetDependency().Received() .LogCollectionEventAsync(collection, EventType.Collection_Created); Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); @@ -51,7 +38,7 @@ public class CollectionServiceTest [Theory, BitAutoData] public async Task SaveAsync_DefaultIdWithGroupsAndUsers_CreateCollectionWithGroupsAndUsersInRepository(Collection collection, - IEnumerable groups, IEnumerable users, Organization organization, SutProvider sutProvider) + [CollectionAccessSelectionCustomize(true)] IEnumerable groups, IEnumerable users, Organization organization, SutProvider sutProvider) { collection.Id = default; organization.UseGroups = true; @@ -60,7 +47,9 @@ public class CollectionServiceTest await sutProvider.Sut.SaveAsync(collection, groups, users); - await sutProvider.GetDependency().Received().CreateAsync(collection, groups, users); + await sutProvider.GetDependency().Received() + .CreateAsync(collection, Arg.Is>(l => l.Any(i => i.Manage == true)), + Arg.Any>()); await sutProvider.GetDependency().Received() .LogCollectionEventAsync(collection, EventType.Collection_Created); Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); @@ -68,15 +57,17 @@ public class CollectionServiceTest } [Theory, BitAutoData] - public async Task SaveAsync_NonDefaultId_ReplacesCollectionInRepository(Collection collection, Organization organization, SutProvider sutProvider) + public async Task SaveAsync_NonDefaultId_ReplacesCollectionInRepository(Collection collection, Organization organization, [CollectionAccessSelectionCustomize(true)] IEnumerable users, SutProvider sutProvider) { var creationDate = collection.CreationDate; sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection); + await sutProvider.Sut.SaveAsync(collection, null, users); - await sutProvider.GetDependency().Received().ReplaceAsync(collection, null, null); + await sutProvider.GetDependency().Received().ReplaceAsync(collection, + Arg.Is>(l => l == null), + Arg.Is>(l => l.Any(i => i.Manage == true))); await sutProvider.GetDependency().Received() .LogCollectionEventAsync(collection, EventType.Collection_Updated); Assert.Equal(collection.CreationDate, creationDate); @@ -84,39 +75,20 @@ public class CollectionServiceTest } [Theory, BitAutoData] - public async Task SaveAsync_OrganizationNotUseGroup_CreateCollectionWithoutGroupsInRepository(Collection collection, IEnumerable groups, + public async Task SaveAsync_OrganizationNotUseGroup_CreateCollectionWithoutGroupsInRepository(Collection collection, + IEnumerable groups, [CollectionAccessSelectionCustomize(true)] IEnumerable users, Organization organization, SutProvider sutProvider) { collection.Id = default; + organization.UseGroups = false; sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection, groups); + await sutProvider.Sut.SaveAsync(collection, groups, users); - await sutProvider.GetDependency().Received().CreateAsync(collection, null, null); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, BitAutoData] - public async Task SaveAsync_DefaultIdWithUserId_UpdateUserInCollectionRepository(Collection collection, - Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) - { - collection.Id = default; - organizationUser.Status = OrganizationUserStatusType.Confirmed; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, organizationUser.Id) - .Returns(organizationUser); - var utcNow = DateTime.UtcNow; - - await sutProvider.Sut.SaveAsync(collection, null, null, organizationUser.Id); - - await sutProvider.GetDependency().Received().CreateAsync(collection, null, null); - await sutProvider.GetDependency().Received() - .GetByOrganizationAsync(organization.Id, organizationUser.Id); - await sutProvider.GetDependency().Received().UpdateUsersAsync(collection.Id, Arg.Any>()); + await sutProvider.GetDependency().Received().CreateAsync(collection, + Arg.Is>(l => l == null), + Arg.Is>(l => l.Any(i => i.Manage == true))); await sutProvider.GetDependency().Received() .LogCollectionEventAsync(collection, EventType.Collection_Created); Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); @@ -135,14 +107,31 @@ public class CollectionServiceTest } [Theory, BitAutoData] - public async Task SaveAsync_ExceedsOrganizationMaxCollections_ThrowsBadRequest(Collection collection, Organization organization, SutProvider sutProvider) + public async Task SaveAsync_NoManageAccess_ThrowsBadRequest(Collection collection, Organization organization, + [CollectionAccessSelectionCustomize] IEnumerable users, SutProvider sutProvider) + { + collection.Id = default; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection, null, users)); + Assert.Contains("At least one member or group must have can manage permission.", ex.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCollectionEventAsync(default, default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExceedsOrganizationMaxCollections_ThrowsBadRequest(Collection collection, + Organization organization, [CollectionAccessSelectionCustomize(true)] IEnumerable users, + SutProvider sutProvider) { collection.Id = default; sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency().GetCountByOrganizationIdAsync(organization.Id) .Returns(organization.MaxCollections.Value); - var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection)); + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection, null, users)); Assert.Equal($@"You have reached the maximum number of collections ({organization.MaxCollections.Value}) for this organization.", ex.Message); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default, default);