From e2d644f1365f93b43d51ac964e37b75fe6c6368a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Thu, 23 Nov 2023 12:21:20 +0000 Subject: [PATCH] [AC-1116] Assign new imported collections to the importing user with Manage permission (#3424) * [AC-1116] Assigning imported collections to the importing user with Manage permission * [AC-1116] Added unit tests --- .../Vault/Repositories/ICipherRepository.cs | 2 +- .../Services/Implementations/CipherService.cs | 23 +++++-- .../Vault/Repositories/CipherRepository.cs | 59 +++++++++++++++++- .../Vault/Repositories/CipherRepository.cs | 12 +++- .../Vault/Services/CipherServiceTests.cs | 61 +++++++++++++++++++ 5 files changed, 148 insertions(+), 9 deletions(-) diff --git a/src/Core/Vault/Repositories/ICipherRepository.cs b/src/Core/Vault/Repositories/ICipherRepository.cs index 401add13da..0ba80857d6 100644 --- a/src/Core/Vault/Repositories/ICipherRepository.cs +++ b/src/Core/Vault/Repositories/ICipherRepository.cs @@ -32,7 +32,7 @@ public interface ICipherRepository : IRepository Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers); Task CreateAsync(IEnumerable ciphers, IEnumerable folders); Task CreateAsync(IEnumerable ciphers, IEnumerable collections, - IEnumerable collectionCiphers); + IEnumerable collectionCiphers, IEnumerable collectionUsers); Task SoftDeleteAsync(IEnumerable ids, Guid userId); Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); Task RestoreAsync(IEnumerable ids, Guid userId); diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index 72437ec1b6..6e5b15de0d 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -27,6 +27,7 @@ public class CipherService : ICipherService private readonly ICollectionRepository _collectionRepository; private readonly IUserRepository _userRepository; private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ICollectionCipherRepository _collectionCipherRepository; private readonly IPushNotificationService _pushService; private readonly IAttachmentStorageService _attachmentStorageService; @@ -34,7 +35,7 @@ public class CipherService : ICipherService private readonly IUserService _userService; private readonly IPolicyService _policyService; private readonly GlobalSettings _globalSettings; - private const long _fileSizeLeeway = 1024L * 1024L; // 1MB + private const long _fileSizeLeeway = 1024L * 1024L; // 1MB private readonly IReferenceEventService _referenceEventService; private readonly ICurrentContext _currentContext; @@ -44,6 +45,7 @@ public class CipherService : ICipherService ICollectionRepository collectionRepository, IUserRepository userRepository, IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, ICollectionCipherRepository collectionCipherRepository, IPushNotificationService pushService, IAttachmentStorageService attachmentStorageService, @@ -59,6 +61,7 @@ public class CipherService : ICipherService _collectionRepository = collectionRepository; _userRepository = userRepository; _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; _collectionCipherRepository = collectionCipherRepository; _pushService = pushService; _attachmentStorageService = attachmentStorageService; @@ -652,7 +655,7 @@ public class CipherService : ICipherService cipher.RevisionDate = DateTime.UtcNow; - // The sprocs will validate that all collections belong to this org/user and that they have + // The sprocs will validate that all collections belong to this org/user and that they have // proper write permissions. if (orgAdmin) { @@ -747,6 +750,7 @@ public class CipherService : ICipherService var org = collections.Count > 0 ? await _organizationRepository.GetByIdAsync(collections[0].OrganizationId) : await _organizationRepository.GetByIdAsync(ciphers.FirstOrDefault(c => c.OrganizationId.HasValue).OrganizationId.Value); + var importingOrgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, importingUserId); if (collections.Count > 0 && org != null && org.MaxCollections.HasValue) { @@ -764,18 +768,25 @@ public class CipherService : ICipherService cipher.SetNewId(); } - var userCollectionsIds = (await _collectionRepository.GetManyByOrganizationIdAsync(org.Id)).Select(c => c.Id).ToList(); + var organizationCollectionsIds = (await _collectionRepository.GetManyByOrganizationIdAsync(org.Id)).Select(c => c.Id).ToList(); //Assign id to the ones that don't exist in DB //Need to keep the list order to create the relationships - List newCollections = new List(); + var newCollections = new List(); + var newCollectionUsers = new List(); foreach (var collection in collections) { - if (!userCollectionsIds.Contains(collection.Id)) + if (!organizationCollectionsIds.Contains(collection.Id)) { collection.SetNewId(); newCollections.Add(collection); + newCollectionUsers.Add(new CollectionUser + { + CollectionId = collection.Id, + OrganizationUserId = importingOrgUser.Id, + Manage = true + }); } } @@ -799,7 +810,7 @@ public class CipherService : ICipherService } // Create it all - await _cipherRepository.CreateAsync(ciphers, newCollections, collectionCiphers); + await _cipherRepository.CreateAsync(ciphers, newCollections, collectionCiphers, newCollectionUsers); // push await _pushService.PushSyncVaultAsync(importingUserId); diff --git a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs index dfc62f3049..4f6c8e25f7 100644 --- a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs @@ -589,7 +589,7 @@ public class CipherRepository : Repository, ICipherRepository } public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, - IEnumerable collectionCiphers) + IEnumerable collectionCiphers, IEnumerable collectionUsers) { if (!ciphers.Any()) { @@ -631,6 +631,16 @@ public class CipherRepository : Repository, ICipherRepository } } + if (collectionUsers.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[CollectionUser]"; + var dataTable = BuildCollectionUsersTable(bulkCopy, collectionUsers); + bulkCopy.WriteToServer(dataTable); + } + } + await connection.ExecuteAsync( $"[{Schema}].[User_BumpAccountRevisionDateByOrganizationId]", new { OrganizationId = ciphers.First().OrganizationId }, @@ -896,6 +906,53 @@ public class CipherRepository : Repository, ICipherRepository return collectionCiphersTable; } + private DataTable BuildCollectionUsersTable(SqlBulkCopy bulkCopy, IEnumerable collectionUsers) + { + var cu = collectionUsers.FirstOrDefault(); + if (cu == null) + { + throw new ApplicationException("Must have some collectionUsers to bulk import."); + } + + var collectionUsersTable = new DataTable("CollectionUserDataTable"); + + var collectionIdColumn = new DataColumn(nameof(cu.CollectionId), cu.CollectionId.GetType()); + collectionUsersTable.Columns.Add(collectionIdColumn); + var organizationUserIdColumn = new DataColumn(nameof(cu.OrganizationUserId), cu.OrganizationUserId.GetType()); + collectionUsersTable.Columns.Add(organizationUserIdColumn); + var readOnlyColumn = new DataColumn(nameof(cu.ReadOnly), cu.ReadOnly.GetType()); + collectionUsersTable.Columns.Add(readOnlyColumn); + var hidePasswordsColumn = new DataColumn(nameof(cu.HidePasswords), cu.HidePasswords.GetType()); + collectionUsersTable.Columns.Add(hidePasswordsColumn); + var manageColumn = new DataColumn(nameof(cu.Manage), cu.Manage.GetType()); + collectionUsersTable.Columns.Add(manageColumn); + + foreach (DataColumn col in collectionUsersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[2]; + keys[0] = collectionIdColumn; + keys[1] = organizationUserIdColumn; + collectionUsersTable.PrimaryKey = keys; + + foreach (var collectionUser in collectionUsers) + { + var row = collectionUsersTable.NewRow(); + + row[collectionIdColumn] = collectionUser.CollectionId; + row[organizationUserIdColumn] = collectionUser.OrganizationUserId; + row[readOnlyColumn] = collectionUser.ReadOnly; + row[hidePasswordsColumn] = collectionUser.HidePasswords; + row[manageColumn] = collectionUser.Manage; + + collectionUsersTable.Rows.Add(row); + } + + return collectionUsersTable; + } + private DataTable BuildSendsTable(SqlBulkCopy bulkCopy, IEnumerable sends) { var s = sends.FirstOrDefault(); diff --git a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs index c575838362..573e4fc3bd 100644 --- a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs @@ -161,7 +161,10 @@ public class CipherRepository : Repository ciphers, IEnumerable collections, IEnumerable collectionCiphers) + public async Task CreateAsync(IEnumerable ciphers, + IEnumerable collections, + IEnumerable collectionCiphers, + IEnumerable collectionUsers) { if (!ciphers.Any()) { @@ -184,6 +187,13 @@ public class CipherRepository : Repository>(collectionCiphers); await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionCipherEntities); } + + if (collectionUsers.Any()) + { + var collectionUserEntities = Mapper.Map>(collectionUsers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionUserEntities); + } + await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(ciphers.First().OrganizationId.Value); await dbContext.SaveChangesAsync(); } diff --git a/test/Core.Test/Vault/Services/CipherServiceTests.cs b/test/Core.Test/Vault/Services/CipherServiceTests.cs index ea3309fc69..d299084c9c 100644 --- a/test/Core.Test/Vault/Services/CipherServiceTests.cs +++ b/test/Core.Test/Vault/Services/CipherServiceTests.cs @@ -4,6 +4,9 @@ using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Test.AutoFixture.CipherFixtures; +using Bit.Core.Tools.Enums; +using Bit.Core.Tools.Models.Business; +using Bit.Core.Tools.Services; using Bit.Core.Utilities; using Bit.Core.Vault.Entities; using Bit.Core.Vault.Models.Data; @@ -21,6 +24,64 @@ namespace Bit.Core.Test.Services; [SutProviderCustomize] public class CipherServiceTests { + [Theory, BitAutoData] + public async Task ImportCiphersAsync_IntoOrganization_Success( + Organization organization, + Guid importingUserId, + OrganizationUser importingOrganizationUser, + List collections, + List ciphers, + SutProvider sutProvider) + { + organization.MaxCollections = null; + importingOrganizationUser.OrganizationId = organization.Id; + + foreach (var collection in collections) + { + collection.OrganizationId = organization.Id; + } + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organization.Id; + } + + KeyValuePair[] collectionRelationships = { + new(0, 0), + new(1, 1), + new(2, 2) + }; + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + sutProvider.GetDependency() + .GetByOrganizationAsync(organization.Id, importingUserId) + .Returns(importingOrganizationUser); + + // Set up a collection that already exists in the organization + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(organization.Id) + .Returns(new List { collections[0] }); + + await sutProvider.Sut.ImportCiphersAsync(collections, ciphers, collectionRelationships, importingUserId); + + await sutProvider.GetDependency().Received(1).CreateAsync( + ciphers, + Arg.Is>(cols => cols.Count() == collections.Count - 1 && + !cols.Any(c => c.Id == collections[0].Id) && // Check that the collection that already existed in the organization was not added + cols.All(c => collections.Any(x => c.Name == x.Name))), + Arg.Is>(c => c.Count() == ciphers.Count), + Arg.Is>(cus => + cus.Count() == collections.Count - 1 && + !cus.Any(cu => cu.CollectionId == collections[0].Id) && // Check that access was not added for the collection that already existed in the organization + cus.All(cu => cu.OrganizationUserId == importingOrganizationUser.Id && cu.Manage == true))); + await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); + await sutProvider.GetDependency().Received(1).RaiseEventAsync( + Arg.Is(e => e.Type == ReferenceEventType.VaultImported)); + } + [Theory, BitAutoData] public async Task SaveAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher) {