From abe593d22122cda682b7b789a4eadb6ceccfad85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Wed, 2 Apr 2025 10:52:23 +0100 Subject: [PATCH 01/15] [PM-18088] Implement LimitItemDeletion permission checks for all cipher operations (#5476) * Implement enhanced cipher deletion and restore permissions with feature flag support - Add new method `CanDeleteOrRestoreCipherAsAdminAsync` in CiphersController - Update NormalCipherPermissions to support more flexible cipher type checking - Modify CipherService to use new permission checks with feature flag - Refactor test methods to support new permission logic - Improve authorization checks for organization cipher management * Refactor cipher methods to use CipherDetails and simplify type handling - Update CiphersController to use GetByIdAsync with userId - Modify NormalCipherPermissions to remove unnecessary type casting - Update ICipherService and CipherService method signatures to use CipherDetails - Remove redundant type checking in CipherService methods - Improve type consistency in cipher-related operations * Enhance CiphersControllerTests with detailed permission and feature flag scenarios - Add test methods for DeleteAdmin with edit and manage permission checks - Implement tests for LimitItemDeletion feature flag scenarios - Update test method names to reflect more precise permission conditions - Improve test coverage for admin cipher deletion with granular permission handling * Add comprehensive test coverage for admin cipher restore operations - Implement test methods for PutRestoreAdmin and PutRestoreManyAdmin - Add scenarios for owner and admin roles with LimitItemDeletion feature flag - Cover permission checks for manage and edit permissions - Enhance test coverage for single and bulk cipher restore admin operations - Verify correct invocation of RestoreAsync and RestoreManyAsync methods * Refactor CiphersControllerTests to remove redundant assertions and mocking - Remove unnecessary assertions for null checks - Simplify mocking setup for cipher repository and service methods - Clean up redundant type and data setup in test methods - Improve test method clarity by removing extraneous code * Add comprehensive test coverage for cipher restore, delete, and soft delete operations - Implement test methods for RestoreAsync with org admin override and LimitItemDeletion feature flag - Add scenarios for checking manage and edit permissions during restore operations - Extend test coverage for DeleteAsync with similar permission and feature flag checks - Enhance SoftDeleteAsync tests with org admin override and permission validation - Improve test method names to reflect precise permission conditions * Add comprehensive test coverage for cipher restore, delete, and soft delete operations - Extend test methods for RestoreManyAsync with various permission scenarios - Add test coverage for personal and organization ciphers in restore operations - Implement tests for RestoreManyAsync with LimitItemDeletion feature flag - Add detailed test scenarios for delete and soft delete operations - Improve test method names to reflect precise permission and feature flag conditions * Refactor authorization checks in CiphersController to use All() method for improved readability * Refactor filtering of ciphers in CipherService to streamline organization ability checks and improve readability --- .../Vault/Controllers/CiphersController.cs | 96 +- src/Core/Vault/Services/ICipherService.cs | 6 +- .../Services/Implementations/CipherService.cs | 131 +- .../Controllers/CiphersControllerTests.cs | 957 +++++++++++--- .../Vault/Services/CipherServiceTests.cs | 1155 +++++++++++++++-- 5 files changed, 2067 insertions(+), 278 deletions(-) diff --git a/src/Api/Vault/Controllers/CiphersController.cs b/src/Api/Vault/Controllers/CiphersController.cs index daaf8a03fb..0f03f54be1 100644 --- a/src/Api/Vault/Controllers/CiphersController.cs +++ b/src/Api/Vault/Controllers/CiphersController.cs @@ -16,6 +16,7 @@ using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Tools.Services; using Bit.Core.Utilities; +using Bit.Core.Vault.Authorization.Permissions; using Bit.Core.Vault.Entities; using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Queries; @@ -345,6 +346,77 @@ public class CiphersController : Controller return await CanEditCiphersAsync(organizationId, cipherIds); } + private async Task CanDeleteOrRestoreCipherAsAdminAsync(Guid organizationId, IEnumerable cipherIds) + { + if (!_featureService.IsEnabled(FeatureFlagKeys.LimitItemDeletion)) + { + return await CanEditCipherAsAdminAsync(organizationId, cipherIds); + } + + var org = _currentContext.GetOrganization(organizationId); + + // If we're not an "admin", we don't need to check the ciphers + if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true })) + { + // Are we a provider user? If so, we need to be sure we're not restricted + // Once the feature flag is removed, this check can be combined with the above + if (await _currentContext.ProviderUserForOrgAsync(organizationId)) + { + // Provider is restricted from editing ciphers, so we're not an "admin" + if (_featureService.IsEnabled(FeatureFlagKeys.RestrictProviderAccess)) + { + return false; + } + + // Provider is unrestricted, so we're an "admin", don't return early + } + else + { + // Not a provider or admin + return false; + } + } + + // If the user can edit all ciphers for the organization, just check they all belong to the org + if (await CanEditAllCiphersAsync(organizationId)) + { + // TODO: This can likely be optimized to only query the requested ciphers and then checking they belong to the org + var orgCiphers = (await _cipherRepository.GetManyByOrganizationIdAsync(organizationId)).ToDictionary(c => c.Id); + + // Ensure all requested ciphers are in orgCiphers + return cipherIds.All(c => orgCiphers.ContainsKey(c)); + } + + // The user cannot access any ciphers for the organization, we're done + if (!await CanAccessOrganizationCiphersAsync(organizationId)) + { + return false; + } + + var user = await _userService.GetUserByPrincipalAsync(User); + // Select all deletable ciphers for this user belonging to the organization + var deletableOrgCipherList = (await _cipherRepository.GetManyByUserIdAsync(user.Id, true)) + .Where(c => c.OrganizationId == organizationId && c.UserId == null).ToList(); + + // Special case for unassigned ciphers + if (await CanAccessUnassignedCiphersAsync(organizationId)) + { + var unassignedCiphers = + (await _cipherRepository.GetManyUnassignedOrganizationDetailsByOrganizationIdAsync( + organizationId)); + + // Users that can access unassigned ciphers can also delete them + deletableOrgCipherList.AddRange(unassignedCiphers.Select(c => new CipherDetails(c) { Manage = true })); + } + + var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId); + var deletableOrgCiphers = deletableOrgCipherList + .Where(c => NormalCipherPermissions.CanDelete(user, c, organizationAbility)) + .ToDictionary(c => c.Id); + + return cipherIds.All(c => deletableOrgCiphers.ContainsKey(c)); + } + /// /// TODO: Move this to its own authorization handler or equivalent service - AC-2062 /// @@ -763,12 +835,12 @@ public class CiphersController : Controller [HttpDelete("{id}/admin")] [HttpPost("{id}/delete-admin")] - public async Task DeleteAdmin(string id) + public async Task DeleteAdmin(Guid id) { var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + var cipher = await GetByIdAsync(id, userId); if (cipher == null || !cipher.OrganizationId.HasValue || - !await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) + !await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) { throw new NotFoundException(); } @@ -808,7 +880,7 @@ public class CiphersController : Controller var cipherIds = model.Ids.Select(i => new Guid(i)).ToList(); if (string.IsNullOrWhiteSpace(model.OrganizationId) || - !await CanEditCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds)) + !await CanDeleteOrRestoreCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds)) { throw new NotFoundException(); } @@ -830,12 +902,12 @@ public class CiphersController : Controller } [HttpPut("{id}/delete-admin")] - public async Task PutDeleteAdmin(string id) + public async Task PutDeleteAdmin(Guid id) { var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + var cipher = await GetByIdAsync(id, userId); if (cipher == null || !cipher.OrganizationId.HasValue || - !await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) + !await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) { throw new NotFoundException(); } @@ -871,7 +943,7 @@ public class CiphersController : Controller var cipherIds = model.Ids.Select(i => new Guid(i)).ToList(); if (string.IsNullOrWhiteSpace(model.OrganizationId) || - !await CanEditCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds)) + !await CanDeleteOrRestoreCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds)) { throw new NotFoundException(); } @@ -899,12 +971,12 @@ public class CiphersController : Controller } [HttpPut("{id}/restore-admin")] - public async Task PutRestoreAdmin(string id) + public async Task PutRestoreAdmin(Guid id) { var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); + var cipher = await GetByIdAsync(id, userId); if (cipher == null || !cipher.OrganizationId.HasValue || - !await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) + !await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) { throw new NotFoundException(); } @@ -944,7 +1016,7 @@ public class CiphersController : Controller var cipherIdsToRestore = new HashSet(model.Ids.Select(i => new Guid(i))); - if (model.OrganizationId == default || !await CanEditCipherAsAdminAsync(model.OrganizationId, cipherIdsToRestore)) + if (model.OrganizationId == default || !await CanDeleteOrRestoreCipherAsAdminAsync(model.OrganizationId, cipherIdsToRestore)) { throw new NotFoundException(); } diff --git a/src/Core/Vault/Services/ICipherService.cs b/src/Core/Vault/Services/ICipherService.cs index 17f55cb47d..7eeb6d2463 100644 --- a/src/Core/Vault/Services/ICipherService.cs +++ b/src/Core/Vault/Services/ICipherService.cs @@ -15,7 +15,7 @@ public interface ICipherService long requestLength, Guid savingUserId, bool orgAdmin = false); Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, long requestLength, string attachmentId, Guid organizationShareId); - Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); + Task DeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false); Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false); Task PurgeAsync(Guid organizationId); @@ -27,9 +27,9 @@ public interface ICipherService Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> ciphers, Guid organizationId, IEnumerable collectionIds, Guid sharingUserId); Task SaveCollectionsAsync(Cipher cipher, IEnumerable collectionIds, Guid savingUserId, bool orgAdmin); - Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); + Task SoftDeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false); Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); - Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false); + Task RestoreAsync(CipherDetails cipherDetails, Guid restoringUserId, bool orgAdmin = false); Task> RestoreManyAsync(IEnumerable cipherIds, Guid restoringUserId, Guid? organizationId = null, bool orgAdmin = false); Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId); Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId); diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index b9daafe599..989fbf43b8 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -14,6 +14,7 @@ using Bit.Core.Tools.Enums; using Bit.Core.Tools.Models.Business; using Bit.Core.Tools.Services; using Bit.Core.Utilities; +using Bit.Core.Vault.Authorization.Permissions; using Bit.Core.Vault.Entities; using Bit.Core.Vault.Enums; using Bit.Core.Vault.Models.Data; @@ -44,6 +45,7 @@ public class CipherService : ICipherService private readonly ICurrentContext _currentContext; private readonly IGetCipherPermissionsForUserQuery _getCipherPermissionsForUserQuery; private readonly IPolicyRequirementQuery _policyRequirementQuery; + private readonly IApplicationCacheService _applicationCacheService; private readonly IFeatureService _featureService; public CipherService( @@ -64,6 +66,7 @@ public class CipherService : ICipherService ICurrentContext currentContext, IGetCipherPermissionsForUserQuery getCipherPermissionsForUserQuery, IPolicyRequirementQuery policyRequirementQuery, + IApplicationCacheService applicationCacheService, IFeatureService featureService) { _cipherRepository = cipherRepository; @@ -83,6 +86,7 @@ public class CipherService : ICipherService _currentContext = currentContext; _getCipherPermissionsForUserQuery = getCipherPermissionsForUserQuery; _policyRequirementQuery = policyRequirementQuery; + _applicationCacheService = applicationCacheService; _featureService = featureService; } @@ -421,19 +425,19 @@ public class CipherService : ICipherService return response; } - public async Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) + public async Task DeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false) { - if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) + if (!orgAdmin && !await UserCanDeleteAsync(cipherDetails, deletingUserId)) { throw new BadRequestException("You do not have permissions to delete this."); } - await _cipherRepository.DeleteAsync(cipher); - await _attachmentStorageService.DeleteAttachmentsForCipherAsync(cipher.Id); - await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Deleted); + await _cipherRepository.DeleteAsync(cipherDetails); + await _attachmentStorageService.DeleteAttachmentsForCipherAsync(cipherDetails.Id); + await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_Deleted); // push - await _pushService.PushSyncCipherDeleteAsync(cipher); + await _pushService.PushSyncCipherDeleteAsync(cipherDetails); } public async Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false) @@ -450,8 +454,8 @@ public class CipherService : ICipherService else { var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); - + var filteredCiphers = await FilterCiphersByDeletePermission(ciphers, cipherIdsSet, deletingUserId); + deletingCiphers = filteredCiphers.Select(c => (Cipher)c).ToList(); await _cipherRepository.DeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); } @@ -703,33 +707,26 @@ public class CipherService : ICipherService await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); } - public async Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) + public async Task SoftDeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false) { - if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) + if (!orgAdmin && !await UserCanDeleteAsync(cipherDetails, deletingUserId)) { throw new BadRequestException("You do not have permissions to soft delete this."); } - if (cipher.DeletedDate.HasValue) + if (cipherDetails.DeletedDate.HasValue) { // Already soft-deleted, we can safely ignore this return; } - cipher.DeletedDate = cipher.RevisionDate = DateTime.UtcNow; + cipherDetails.DeletedDate = cipherDetails.RevisionDate = DateTime.UtcNow; - if (cipher is CipherDetails details) - { - await _cipherRepository.UpsertAsync(details); - } - else - { - await _cipherRepository.UpsertAsync(cipher); - } - await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_SoftDeleted); + await _cipherRepository.UpsertAsync(cipherDetails); + await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted); // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); + await _pushService.PushSyncCipherUpdateAsync(cipherDetails, null); } public async Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId, bool orgAdmin) @@ -746,8 +743,8 @@ public class CipherService : ICipherService else { var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); - + var filteredCiphers = await FilterCiphersByDeletePermission(ciphers, cipherIdsSet, deletingUserId); + deletingCiphers = filteredCiphers.Select(c => (Cipher)c).ToList(); await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); } @@ -762,34 +759,27 @@ public class CipherService : ICipherService await _pushService.PushSyncCiphersAsync(deletingUserId); } - public async Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false) + public async Task RestoreAsync(CipherDetails cipherDetails, Guid restoringUserId, bool orgAdmin = false) { - if (!orgAdmin && !(await UserCanEditAsync(cipher, restoringUserId))) + if (!orgAdmin && !await UserCanRestoreAsync(cipherDetails, restoringUserId)) { throw new BadRequestException("You do not have permissions to delete this."); } - if (!cipher.DeletedDate.HasValue) + if (!cipherDetails.DeletedDate.HasValue) { // Already restored, we can safely ignore this return; } - cipher.DeletedDate = null; - cipher.RevisionDate = DateTime.UtcNow; + cipherDetails.DeletedDate = null; + cipherDetails.RevisionDate = DateTime.UtcNow; - if (cipher is CipherDetails details) - { - await _cipherRepository.UpsertAsync(details); - } - else - { - await _cipherRepository.UpsertAsync(cipher); - } - await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Restored); + await _cipherRepository.UpsertAsync(cipherDetails); + await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_Restored); // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); + await _pushService.PushSyncCipherUpdateAsync(cipherDetails, null); } public async Task> RestoreManyAsync(IEnumerable cipherIds, Guid restoringUserId, Guid? organizationId = null, bool orgAdmin = false) @@ -812,8 +802,8 @@ public class CipherService : ICipherService else { var ciphers = await _cipherRepository.GetManyByUserIdAsync(restoringUserId); - restoringCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(c => (CipherOrganizationDetails)c).ToList(); - + var filteredCiphers = await FilterCiphersByDeletePermission(ciphers, cipherIdsSet, restoringUserId); + restoringCiphers = filteredCiphers.Select(c => (CipherOrganizationDetails)c).ToList(); revisionDate = await _cipherRepository.RestoreAsync(restoringCiphers.Select(c => c.Id), restoringUserId); } @@ -844,6 +834,34 @@ public class CipherService : ICipherService return await _cipherRepository.GetCanEditByIdAsync(userId, cipher.Id); } + private async Task UserCanDeleteAsync(CipherDetails cipher, Guid userId) + { + if (!_featureService.IsEnabled(FeatureFlagKeys.LimitItemDeletion)) + { + return await UserCanEditAsync(cipher, userId); + } + + var user = await _userService.GetUserByIdAsync(userId); + var organizationAbility = cipher.OrganizationId.HasValue ? + await _applicationCacheService.GetOrganizationAbilityAsync(cipher.OrganizationId.Value) : null; + + return NormalCipherPermissions.CanDelete(user, cipher, organizationAbility); + } + + private async Task UserCanRestoreAsync(CipherDetails cipher, Guid userId) + { + if (!_featureService.IsEnabled(FeatureFlagKeys.LimitItemDeletion)) + { + return await UserCanEditAsync(cipher, userId); + } + + var user = await _userService.GetUserByIdAsync(userId); + var organizationAbility = cipher.OrganizationId.HasValue ? + await _applicationCacheService.GetOrganizationAbilityAsync(cipher.OrganizationId.Value) : null; + + return NormalCipherPermissions.CanRestore(user, cipher, organizationAbility); + } + private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate) { if (cipher.Id == default || !lastKnownRevisionDate.HasValue) @@ -1010,4 +1028,35 @@ public class CipherService : ICipherService cipher.Data = JsonSerializer.Serialize(newCipherData); } } + + // This method is used to filter ciphers based on the user's permissions to delete them. + // It supports both the old and new logic depending on the feature flag. + private async Task> FilterCiphersByDeletePermission( + IEnumerable ciphers, + HashSet cipherIdsSet, + Guid userId) where T : CipherDetails + { + if (!_featureService.IsEnabled(FeatureFlagKeys.LimitItemDeletion)) + { + return ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).ToList(); + } + + var user = await _userService.GetUserByIdAsync(userId); + var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + + var filteredCiphers = ciphers + .Where(c => cipherIdsSet.Contains(c.Id)) + .GroupBy(c => c.OrganizationId) + .SelectMany(group => + { + var organizationAbility = group.Key.HasValue && + organizationAbilities.TryGetValue(group.Key.Value, out var ability) ? + ability : null; + + return group.Where(c => NormalCipherPermissions.CanDelete(user, c, organizationAbility)); + }) + .ToList(); + + return filteredCiphers; + } } diff --git a/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs b/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs index 14013d9c1c..0bdc6ab545 100644 --- a/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs +++ b/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs @@ -157,9 +157,9 @@ public class CiphersControllerTests [BitAutoData(OrganizationUserType.Custom, false, false)] public async Task CanEditCiphersAsAdminAsync_FlexibleCollections_Success( OrganizationUserType userType, bool allowAdminsAccessToAllItems, bool shouldSucceed, - CurrentContextOrganization organization, Guid userId, Cipher cipher, SutProvider sutProvider) + CurrentContextOrganization organization, Guid userId, CipherDetails cipherDetails, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; + cipherDetails.OrganizationId = organization.Id; organization.Type = userType; if (userType == OrganizationUserType.Custom) { @@ -169,8 +169,9 @@ public class CiphersControllerTests sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); sutProvider.GetDependency().GetOrganizationAbilityAsync(organization.Id).Returns(new OrganizationAbility { @@ -180,13 +181,13 @@ public class CiphersControllerTests if (shouldSucceed) { - await sutProvider.Sut.DeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); await sutProvider.GetDependency().ReceivedWithAnyArgs() .DeleteAsync(default, default); } else { - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipher.Id.ToString())); + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipherDetails.Id)); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .DeleteAsync(default, default); } @@ -197,10 +198,10 @@ public class CiphersControllerTests [BitAutoData(false)] [BitAutoData(true)] public async Task CanEditCiphersAsAdminAsync_Providers( - bool restrictProviders, Cipher cipher, CurrentContextOrganization organization, Guid userId, SutProvider sutProvider + bool restrictProviders, CipherDetails cipherDetails, CurrentContextOrganization organization, Guid userId, SutProvider sutProvider ) { - cipher.OrganizationId = organization.Id; + cipherDetails.OrganizationId = organization.Id; // Simulate that the user is a provider for the organization sutProvider.GetDependency().EditAnyCollection(organization.Id).Returns(true); @@ -208,8 +209,8 @@ public class CiphersControllerTests sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); sutProvider.GetDependency().GetOrganizationAbilityAsync(organization.Id).Returns(new OrganizationAbility { @@ -221,13 +222,13 @@ public class CiphersControllerTests // Non restricted providers should succeed if (!restrictProviders) { - await sutProvider.Sut.DeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); await sutProvider.GetDependency().ReceivedWithAnyArgs() .DeleteAsync(default, default); } else // Otherwise, they should fail { - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipher.Id.ToString())); + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipherDetails.Id)); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .DeleteAsync(default, default); } @@ -238,93 +239,202 @@ public class CiphersControllerTests [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] - public async Task DeleteAdmin_WithOwnerOrAdmin_WithAccessToSpecificCipher_DeletesCipher( - OrganizationUserType organizationUserType, Cipher cipher, Guid userId, + public async Task DeleteAdmin_WithOwnerOrAdmin_WithEditPermission_DeletesCipher( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = true; + cipherDetails.Manage = false; + organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency() .GetManyByUserIdAsync(userId) .Returns(new List { - new() { Id = cipher.Id, OrganizationId = cipher.OrganizationId, Edit = true } + cipherDetails }); - await sutProvider.Sut.DeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task DeleteAdmin_WithOwnerOrAdmin_WithoutEditPermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + CurrentContextOrganization organization, SutProvider sutProvider) + { + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = false; + cipherDetails.Manage = false; + + organization.Type = organizationUserType; + + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(new List + { + cipherDetails + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipherDetails.Id)); + + await sutProvider.GetDependency().DidNotReceive().DeleteAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task DeleteAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithManagePermission_DeletesCipher( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + CurrentContextOrganization organization, SutProvider sutProvider) + { + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = true; + cipherDetails.Manage = true; + + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(new List + { + cipherDetails + }); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); + + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task DeleteAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithoutManagePermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + CurrentContextOrganization organization, SutProvider sutProvider) + { + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = true; + cipherDetails.Manage = false; + + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(new List + { + cipherDetails + }); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipherDetails.Id)); + + await sutProvider.GetDependency().DidNotReceive().DeleteAsync(Arg.Any(), Arg.Any(), Arg.Any()); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] public async Task DeleteAdmin_WithOwnerOrAdmin_WithAccessToUnassignedCipher_DeletesCipher( - OrganizationUserType organizationUserType, Cipher cipher, Guid userId, + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; + cipherDetails.OrganizationId = organization.Id; organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency() .GetManyUnassignedOrganizationDetailsByOrganizationIdAsync(organization.Id) - .Returns(new List { new() { Id = cipher.Id } }); + .Returns(new List { new() { Id = cipherDetails.Id } }); - await sutProvider.Sut.DeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] - public async Task DeleteAdmin_WithAdminOrOwnerAndAccessToAllCollectionItems_DeletesCipher( - OrganizationUserType organizationUserType, Cipher cipher, Guid userId, + public async Task DeleteAdmin_WithAdminOrOwner_WithAccessToAllCollectionItems_DeletesCipher( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; + cipherDetails.OrganizationId = organization.Id; organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); sutProvider.GetDependency().GetOrganizationAbilityAsync(organization.Id).Returns(new OrganizationAbility { Id = organization.Id, AllowAdminAccessToAllCollectionItems = true }); - await sutProvider.Sut.DeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); } [Theory] [BitAutoData] public async Task DeleteAdmin_WithCustomUser_WithEditAnyCollectionTrue_DeletesCipher( - Cipher cipher, Guid userId, + CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; + cipherDetails.OrganizationId = organization.Id; organization.Type = OrganizationUserType.Custom; organization.Permissions.EditAnyCollection = true; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); - await sutProvider.Sut.DeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); } [Theory] @@ -341,24 +451,24 @@ public class CiphersControllerTests sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipher.Id.ToString())); + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipher.Id)); } [Theory] [BitAutoData] public async Task DeleteAdmin_WithProviderUser_DeletesCipher( - Cipher cipher, Guid userId, SutProvider sutProvider) + CipherDetails cipherDetails, Guid userId, SutProvider sutProvider) { - cipher.OrganizationId = Guid.NewGuid(); + cipherDetails.OrganizationId = Guid.NewGuid(); sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().ProviderUserForOrgAsync(cipher.OrganizationId.Value).Returns(true); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(cipher.OrganizationId.Value).Returns(new List { cipher }); + sutProvider.GetDependency().ProviderUserForOrgAsync(cipherDetails.OrganizationId.Value).Returns(true); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(cipherDetails.OrganizationId.Value).Returns(new List { cipherDetails }); - await sutProvider.Sut.DeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); } [Theory] @@ -373,13 +483,13 @@ public class CiphersControllerTests sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.RestrictProviderAccess).Returns(true); - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipher.Id.ToString())); + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipher.Id)); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] - public async Task DeleteManyAdmin_WithOwnerOrAdmin_WithAccessToSpecificCiphers_DeletesCiphers( + public async Task DeleteManyAdmin_WithOwnerOrAdmin_WithEditPermission_DeletesCiphers( OrganizationUserType organizationUserType, CipherBulkDeleteRequestModel model, Guid userId, List ciphers, CurrentContextOrganization organization, SutProvider sutProvider) { @@ -408,6 +518,122 @@ public class CiphersControllerTests userId, organization.Id, true); } + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task DeleteManyAdmin_WithOwnerOrAdmin_WithoutEditPermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherBulkDeleteRequestModel model, Guid userId, List ciphers, + CurrentContextOrganization organization, SutProvider sutProvider) + { + model.OrganizationId = organization.Id.ToString(); + model.Ids = ciphers.Select(c => c.Id.ToString()).ToList(); + + organization.Type = organizationUserType; + + sutProvider.GetDependency() + .GetProperUserId(default) + .ReturnsForAnyArgs(userId); + + sutProvider.GetDependency() + .GetOrganization(new Guid(model.OrganizationId)) + .Returns(organization); + + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(new Guid(model.OrganizationId)) + .Returns(ciphers); + + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(new Guid(model.OrganizationId)) + .Returns(new OrganizationAbility + { + Id = new Guid(model.OrganizationId), + AllowAdminAccessToAllCollectionItems = false, + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteManyAdmin(model)); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task DeleteManyAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithManagePermission_DeletesCiphers( + OrganizationUserType organizationUserType, CipherBulkDeleteRequestModel model, Guid userId, List ciphers, + CurrentContextOrganization organization, SutProvider sutProvider) + { + model.OrganizationId = organization.Id.ToString(); + model.Ids = ciphers.Select(c => c.Id.ToString()).ToList(); + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organization.Id, + Edit = true, + Manage = true + }).ToList()); + + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await sutProvider.Sut.DeleteManyAdmin(model); + + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync( + Arg.Is>(ids => + ids.All(id => model.Ids.Contains(id.ToString())) && ids.Count() == model.Ids.Count()), + userId, organization.Id, true); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task DeleteManyAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithoutManagePermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherBulkDeleteRequestModel model, Guid userId, List ciphers, + CurrentContextOrganization organization, SutProvider sutProvider) + { + model.OrganizationId = organization.Id.ToString(); + model.Ids = ciphers.Select(c => c.Id.ToString()).ToList(); + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organization.Id, + Edit = true, + Manage = false + }).ToList()); + + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteManyAdmin(model)); + } + [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] @@ -555,94 +781,203 @@ public class CiphersControllerTests [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] - public async Task PutDeleteAdmin_WithOwnerOrAdmin_WithAccessToSpecificCipher_SoftDeletesCipher( - OrganizationUserType organizationUserType, Cipher cipher, Guid userId, + public async Task PutDeleteAdmin_WithOwnerOrAdmin_WithEditPermission_SoftDeletesCipher( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = true; + organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); sutProvider.GetDependency() .GetManyByUserIdAsync(userId) .Returns(new List { - new() { Id = cipher.Id, OrganizationId = cipher.OrganizationId, Edit = true } + cipherDetails }); - await sutProvider.Sut.PutDeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.PutDeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipherDetails, userId, true); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutDeleteAdmin_WithOwnerOrAdmin_WithoutEditPermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + CurrentContextOrganization organization, SutProvider sutProvider) + { + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = false; + cipherDetails.Manage = false; + + organization.Type = organizationUserType; + + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(new List + { + cipherDetails + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PutDeleteAdmin(cipherDetails.Id)); + + await sutProvider.GetDependency().DidNotReceive().SoftDeleteAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutDeleteAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithManagePermission_SoftDeletesCipher( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + CurrentContextOrganization organization, SutProvider sutProvider) + { + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = true; + cipherDetails.Manage = true; + + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(new List + { + cipherDetails + }); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await sutProvider.Sut.PutDeleteAdmin(cipherDetails.Id); + + await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipherDetails, userId, true); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutDeleteAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithoutManagePermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + CurrentContextOrganization organization, SutProvider sutProvider) + { + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = true; + cipherDetails.Manage = false; + + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(new List + { + cipherDetails + }); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PutDeleteAdmin(cipherDetails.Id)); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SoftDeleteManyAsync(default, default, default, default); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] public async Task PutDeleteAdmin_WithOwnerOrAdmin_WithAccessToUnassignedCipher_SoftDeletesCipher( - OrganizationUserType organizationUserType, Cipher cipher, Guid userId, + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; + cipherDetails.OrganizationId = organization.Id; organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); sutProvider.GetDependency() .GetManyUnassignedOrganizationDetailsByOrganizationIdAsync(organization.Id) - .Returns(new List { new() { Id = cipher.Id } }); + .Returns(new List { new() { Id = cipherDetails.Id } }); - await sutProvider.Sut.PutDeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.PutDeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipherDetails, userId, true); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] public async Task PutDeleteAdmin_WithOwnerOrAdmin_WithAccessToAllCollectionItems_SoftDeletesCipher( - OrganizationUserType organizationUserType, Cipher cipher, Guid userId, + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; + cipherDetails.OrganizationId = organization.Id; organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); sutProvider.GetDependency().GetOrganizationAbilityAsync(organization.Id).Returns(new OrganizationAbility { Id = organization.Id, AllowAdminAccessToAllCollectionItems = true }); - await sutProvider.Sut.PutDeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.PutDeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipherDetails, userId, true); } [Theory] [BitAutoData] public async Task PutDeleteAdmin_WithCustomUser_WithEditAnyCollectionTrue_SoftDeletesCipher( - Cipher cipher, Guid userId, + CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; + cipherDetails.OrganizationId = organization.Id; organization.Type = OrganizationUserType.Custom; organization.Permissions.EditAnyCollection = true; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); - await sutProvider.Sut.PutDeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.PutDeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipherDetails, userId, true); } [Theory] @@ -660,24 +995,24 @@ public class CiphersControllerTests sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); - await Assert.ThrowsAsync(() => sutProvider.Sut.PutDeleteAdmin(cipher.Id.ToString())); + await Assert.ThrowsAsync(() => sutProvider.Sut.PutDeleteAdmin(cipher.Id)); } [Theory] [BitAutoData] public async Task PutDeleteAdmin_WithProviderUser_SoftDeletesCipher( - Cipher cipher, Guid userId, SutProvider sutProvider) + CipherDetails cipherDetails, Guid userId, SutProvider sutProvider) { - cipher.OrganizationId = Guid.NewGuid(); + cipherDetails.OrganizationId = Guid.NewGuid(); sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().ProviderUserForOrgAsync(cipher.OrganizationId.Value).Returns(true); - sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(cipher.OrganizationId.Value).Returns(new List { cipher }); + sutProvider.GetDependency().ProviderUserForOrgAsync(cipherDetails.OrganizationId.Value).Returns(true); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(cipherDetails.OrganizationId.Value).Returns(new List { cipherDetails }); - await sutProvider.Sut.PutDeleteAdmin(cipher.Id.ToString()); + await sutProvider.Sut.PutDeleteAdmin(cipherDetails.Id); - await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).SoftDeleteAsync(cipherDetails, userId, true); } [Theory] @@ -692,13 +1027,13 @@ public class CiphersControllerTests sutProvider.GetDependency().GetByIdAsync(cipher.Id).Returns(cipher); sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.RestrictProviderAccess).Returns(true); - await Assert.ThrowsAsync(() => sutProvider.Sut.PutDeleteAdmin(cipher.Id.ToString())); + await Assert.ThrowsAsync(() => sutProvider.Sut.PutDeleteAdmin(cipher.Id)); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] - public async Task PutDeleteManyAdmin_WithOwnerOrAdmin_WithAccessToSpecificCiphers_SoftDeletesCiphers( + public async Task PutDeleteManyAdmin_WithOwnerOrAdmin_WithEditPermission_SoftDeletesCiphers( OrganizationUserType organizationUserType, CipherBulkDeleteRequestModel model, Guid userId, List ciphers, CurrentContextOrganization organization, SutProvider sutProvider) { @@ -727,6 +1062,113 @@ public class CiphersControllerTests userId, organization.Id, true); } + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutDeleteManyAdmin_WithOwnerOrAdmin_WithoutEditPermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherBulkDeleteRequestModel model, Guid userId, List ciphers, + CurrentContextOrganization organization, SutProvider sutProvider) + { + model.OrganizationId = organization.Id.ToString(); + model.Ids = ciphers.Select(c => c.Id.ToString()).ToList(); + organization.Type = organizationUserType; + + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organization.Id, + Edit = false + }).ToList()); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PutDeleteManyAdmin(model)); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutDeleteManyAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithManagePermission_SoftDeletesCiphers( + OrganizationUserType organizationUserType, CipherBulkDeleteRequestModel model, Guid userId, List ciphers, + CurrentContextOrganization organization, SutProvider sutProvider) + { + model.OrganizationId = organization.Id.ToString(); + model.Ids = ciphers.Select(c => c.Id.ToString()).ToList(); + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organization.Id, + Edit = true, + Manage = true + }).ToList()); + + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await sutProvider.Sut.PutDeleteManyAdmin(model); + + await sutProvider.GetDependency() + .Received(1) + .SoftDeleteManyAsync( + Arg.Is>(ids => + ids.All(id => model.Ids.Contains(id.ToString())) && ids.Count() == model.Ids.Count()), + userId, organization.Id, true); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutDeleteManyAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithoutManagePermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherBulkDeleteRequestModel model, Guid userId, List ciphers, + CurrentContextOrganization organization, SutProvider sutProvider) + { + model.OrganizationId = organization.Id.ToString(); + model.Ids = ciphers.Select(c => c.Id.ToString()).ToList(); + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organization.Id, + Edit = true, + Manage = false + }).ToList()); + + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PutDeleteManyAdmin(model)); + } + [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] @@ -874,170 +1316,273 @@ public class CiphersControllerTests [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] - public async Task PutRestoreAdmin_WithOwnerOrAdmin_WithAccessToSpecificCipher_RestoresCipher( - OrganizationUserType organizationUserType, CipherDetails cipher, Guid userId, + public async Task PutRestoreAdmin_WithOwnerOrAdmin_WithEditPermission_RestoresCipher( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; - cipher.Type = CipherType.Login; - cipher.Data = JsonSerializer.Serialize(new CipherLoginData()); + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Type = CipherType.Login; + cipherDetails.Data = JsonSerializer.Serialize(new CipherLoginData()); + cipherDetails.Edit = true; + organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); sutProvider.GetDependency() .GetManyByUserIdAsync(userId) .Returns(new List { - new() { Id = cipher.Id, OrganizationId = cipher.OrganizationId, Edit = true } + cipherDetails }); - var result = await sutProvider.Sut.PutRestoreAdmin(cipher.Id.ToString()); + var result = await sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id); - Assert.NotNull(result); Assert.IsType(result); - await sutProvider.GetDependency().Received(1).RestoreAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).RestoreAsync(cipherDetails, userId, true); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutRestoreAdmin_WithOwnerOrAdmin_WithoutEditPermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + CurrentContextOrganization organization, SutProvider sutProvider) + { + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = false; + cipherDetails.Manage = false; + + organization.Type = organizationUserType; + + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(new List + { + cipherDetails + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id)); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutRestoreAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithManagePermission_RestoresCipher( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + CurrentContextOrganization organization, SutProvider sutProvider) + { + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Type = CipherType.Login; + cipherDetails.Data = JsonSerializer.Serialize(new CipherLoginData()); + cipherDetails.Edit = true; + cipherDetails.Manage = true; + + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(new List + { + cipherDetails + }); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + var result = await sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id); + + Assert.IsType(result); + await sutProvider.GetDependency().Received(1).RestoreAsync(cipherDetails, userId, true); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutRestoreAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithoutManagePermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + CurrentContextOrganization organization, SutProvider sutProvider) + { + cipherDetails.UserId = null; + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Edit = true; + cipherDetails.Manage = false; + + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(new List + { + cipherDetails + }); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id)); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] public async Task PutRestoreAdmin_WithOwnerOrAdmin_WithAccessToUnassignedCipher_RestoresCipher( - OrganizationUserType organizationUserType, CipherDetails cipher, Guid userId, + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; - cipher.Type = CipherType.Login; - cipher.Data = JsonSerializer.Serialize(new CipherLoginData()); + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Type = CipherType.Login; + cipherDetails.Data = JsonSerializer.Serialize(new CipherLoginData()); organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); sutProvider.GetDependency() .GetManyUnassignedOrganizationDetailsByOrganizationIdAsync(organization.Id) - .Returns(new List { new() { Id = cipher.Id } }); + .Returns(new List { new() { Id = cipherDetails.Id } }); - var result = await sutProvider.Sut.PutRestoreAdmin(cipher.Id.ToString()); + var result = await sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id); - Assert.NotNull(result); Assert.IsType(result); - await sutProvider.GetDependency().Received(1).RestoreAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).RestoreAsync(cipherDetails, userId, true); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] public async Task PutRestoreAdmin_WithOwnerOrAdmin_WithAccessToAllCollectionItems_RestoresCipher( - OrganizationUserType organizationUserType, CipherDetails cipher, Guid userId, + OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; - cipher.Type = CipherType.Login; - cipher.Data = JsonSerializer.Serialize(new CipherLoginData()); + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Type = CipherType.Login; + cipherDetails.Data = JsonSerializer.Serialize(new CipherLoginData()); organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); sutProvider.GetDependency().GetOrganizationAbilityAsync(organization.Id).Returns(new OrganizationAbility { Id = organization.Id, AllowAdminAccessToAllCollectionItems = true }); - var result = await sutProvider.Sut.PutRestoreAdmin(cipher.Id.ToString()); + var result = await sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id); - Assert.NotNull(result); - await sutProvider.GetDependency().Received(1).RestoreAsync(cipher, userId, true); + Assert.IsType(result); + await sutProvider.GetDependency().Received(1).RestoreAsync(cipherDetails, userId, true); } [Theory] [BitAutoData] public async Task PutRestoreAdmin_WithCustomUser_WithEditAnyCollectionTrue_RestoresCipher( - CipherDetails cipher, Guid userId, + CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; - cipher.Type = CipherType.Login; - cipher.Data = JsonSerializer.Serialize(new CipherLoginData()); + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Type = CipherType.Login; + cipherDetails.Data = JsonSerializer.Serialize(new CipherLoginData()); organization.Type = OrganizationUserType.Custom; organization.Permissions.EditAnyCollection = true; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); - var result = await sutProvider.Sut.PutRestoreAdmin(cipher.Id.ToString()); + var result = await sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id); - Assert.NotNull(result); Assert.IsType(result); - await sutProvider.GetDependency().Received(1).RestoreAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).RestoreAsync(cipherDetails, userId, true); } [Theory] [BitAutoData] public async Task PutRestoreAdmin_WithCustomUser_WithEditAnyCollectionFalse_ThrowsNotFoundException( - CipherDetails cipher, Guid userId, + CipherDetails cipherDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipher.OrganizationId = organization.Id; - cipher.Type = CipherType.Login; - cipher.Data = JsonSerializer.Serialize(new CipherLoginData()); + cipherDetails.OrganizationId = organization.Id; + cipherDetails.Type = CipherType.Login; + cipherDetails.Data = JsonSerializer.Serialize(new CipherLoginData()); organization.Type = OrganizationUserType.Custom; organization.Permissions.EditAnyCollection = false; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipherDetails.Id).Returns(cipherDetails); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipher }); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); - await Assert.ThrowsAsync(() => sutProvider.Sut.PutRestoreAdmin(cipher.Id.ToString())); + await Assert.ThrowsAsync(() => sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id)); } [Theory] [BitAutoData] public async Task PutRestoreAdmin_WithProviderUser_RestoresCipher( - CipherDetails cipher, Guid userId, SutProvider sutProvider) + CipherDetails cipherDetails, Guid userId, SutProvider sutProvider) { - cipher.OrganizationId = Guid.NewGuid(); - cipher.Type = CipherType.Login; - cipher.Data = JsonSerializer.Serialize(new CipherLoginData()); + cipherDetails.OrganizationId = Guid.NewGuid(); + cipherDetails.Type = CipherType.Login; + cipherDetails.Data = JsonSerializer.Serialize(new CipherLoginData()); sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().ProviderUserForOrgAsync(cipher.OrganizationId.Value).Returns(true); - sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipher.Id).Returns(cipher); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(cipher.OrganizationId.Value).Returns(new List { cipher }); + sutProvider.GetDependency().ProviderUserForOrgAsync(cipherDetails.OrganizationId.Value).Returns(true); + sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(cipherDetails.OrganizationId.Value).Returns(new List { cipherDetails }); - var result = await sutProvider.Sut.PutRestoreAdmin(cipher.Id.ToString()); + var result = await sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id); - Assert.NotNull(result); Assert.IsType(result); - await sutProvider.GetDependency().Received(1).RestoreAsync(cipher, userId, true); + await sutProvider.GetDependency().Received(1).RestoreAsync(cipherDetails, userId, true); } [Theory] [BitAutoData] public async Task PutRestoreAdmin_WithProviderUser_WithRestrictProviderAccessTrue_ThrowsNotFoundException( - CipherDetails cipher, Guid userId, SutProvider sutProvider) + CipherDetails cipherDetails, Guid userId, SutProvider sutProvider) { - cipher.OrganizationId = Guid.NewGuid(); + cipherDetails.OrganizationId = Guid.NewGuid(); sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); - sutProvider.GetDependency().ProviderUserForOrgAsync(cipher.OrganizationId.Value).Returns(true); - sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipher.Id).Returns(cipher); + sutProvider.GetDependency().ProviderUserForOrgAsync(cipherDetails.OrganizationId.Value).Returns(true); + sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipherDetails.Id).Returns(cipherDetails); sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.RestrictProviderAccess).Returns(true); - await Assert.ThrowsAsync(() => sutProvider.Sut.PutRestoreAdmin(cipher.Id.ToString())); + await Assert.ThrowsAsync(() => sutProvider.Sut.PutRestoreAdmin(cipherDetails.Id)); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] - public async Task PutRestoreManyAdmin_WithOwnerOrAdmin_WithAccessToSpecificCiphers_RestoresCiphers( + public async Task PutRestoreManyAdmin_WithOwnerOrAdmin_WithEditPermission_RestoresCiphers( OrganizationUserType organizationUserType, CipherBulkRestoreRequestModel model, Guid userId, List ciphers, CurrentContextOrganization organization, SutProvider sutProvider) { @@ -1047,7 +1592,6 @@ public class CiphersControllerTests sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(ciphers); sutProvider.GetDependency() .GetManyByUserIdAsync(userId) .Returns(ciphers.Select(c => new CipherDetails @@ -1071,7 +1615,6 @@ public class CiphersControllerTests var result = await sutProvider.Sut.PutRestoreManyAdmin(model); - Assert.NotNull(result); await sutProvider.GetDependency().Received(1) .RestoreManyAsync( Arg.Is>(ids => @@ -1079,6 +1622,130 @@ public class CiphersControllerTests userId, organization.Id, true); } + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutRestoreManyAdmin_WithOwnerOrAdmin_WithoutEditPermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherBulkRestoreRequestModel model, Guid userId, List ciphers, + CurrentContextOrganization organization, SutProvider sutProvider) + { + model.OrganizationId = organization.Id; + model.Ids = ciphers.Select(c => c.Id.ToString()).ToList(); + organization.Type = organizationUserType; + + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organization.Id, + Edit = false, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new CipherLoginData()) + }).ToList()); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PutRestoreManyAdmin(model)); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutRestoreManyAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithManagePermission_RestoresCiphers( + OrganizationUserType organizationUserType, CipherBulkRestoreRequestModel model, Guid userId, List ciphers, + CurrentContextOrganization organization, SutProvider sutProvider) + { + model.OrganizationId = organization.Id; + model.Ids = ciphers.Select(c => c.Id.ToString()).ToList(); + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organization.Id, + Edit = true, + Manage = true + }).ToList()); + + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + var cipherOrgDetails = ciphers.Select(c => new CipherOrganizationDetails + { + Id = c.Id, + OrganizationId = organization.Id, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new CipherLoginData()) + }).ToList(); + + sutProvider.GetDependency() + .RestoreManyAsync( + Arg.Is>(ids => + ids.All(id => model.Ids.Contains(id.ToString())) && ids.Count == model.Ids.Count()), + userId, organization.Id, true) + .Returns(cipherOrgDetails); + + var result = await sutProvider.Sut.PutRestoreManyAdmin(model); + + Assert.Equal(ciphers.Count, result.Data.Count()); + await sutProvider.GetDependency() + .Received(1) + .RestoreManyAsync( + Arg.Is>(ids => + ids.All(id => model.Ids.Contains(id.ToString())) && ids.Count == model.Ids.Count()), + userId, organization.Id, true); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task PutRestoreManyAdmin_WithLimitItemDeletionEnabled_WithOwnerOrAdmin_WithoutManagePermission_ThrowsNotFoundException( + OrganizationUserType organizationUserType, CipherBulkRestoreRequestModel model, Guid userId, List ciphers, + CurrentContextOrganization organization, SutProvider sutProvider) + { + model.OrganizationId = organization.Id; + model.Ids = ciphers.Select(c => c.Id.ToString()).ToList(); + organization.Type = organizationUserType; + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.LimitItemDeletion).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); + sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); + sutProvider.GetDependency() + .GetManyByUserIdAsync(userId) + .Returns(ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organization.Id, + Edit = true, + Manage = false, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new CipherLoginData()) + }).ToList()); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organization.Id) + .Returns(new OrganizationAbility + { + Id = organization.Id, + LimitItemDeletion = true + }); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PutRestoreManyAdmin(model)); + } + [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] diff --git a/test/Core.Test/Vault/Services/CipherServiceTests.cs b/test/Core.Test/Vault/Services/CipherServiceTests.cs index a7dcbddcea..ed07799c93 100644 --- a/test/Core.Test/Vault/Services/CipherServiceTests.cs +++ b/test/Core.Test/Vault/Services/CipherServiceTests.cs @@ -8,6 +8,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Data.Organizations; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -670,7 +671,7 @@ public class CipherServiceTests [Theory] [BitAutoData] - public async Task RestoreAsync_UpdatesUserCipher(Guid restoringUserId, Cipher cipher, SutProvider sutProvider) + public async Task RestoreAsync_UpdatesUserCipher(Guid restoringUserId, CipherDetails cipher, SutProvider sutProvider) { sutProvider.GetDependency().GetCanEditByIdAsync(restoringUserId, cipher.Id).Returns(true); @@ -687,7 +688,7 @@ public class CipherServiceTests [Theory] [OrganizationCipherCustomize] [BitAutoData] - public async Task RestoreAsync_UpdatesOrganizationCipher(Guid restoringUserId, Cipher cipher, SutProvider sutProvider) + public async Task RestoreAsync_UpdatesOrganizationCipher(Guid restoringUserId, CipherDetails cipher, SutProvider sutProvider) { sutProvider.GetDependency().GetCanEditByIdAsync(restoringUserId, cipher.Id).Returns(true); @@ -704,11 +705,11 @@ public class CipherServiceTests [Theory] [BitAutoData] public async Task RestoreAsync_WithAlreadyRestoredCipher_SkipsOperation( - Guid restoringUserId, Cipher cipher, SutProvider sutProvider) + Guid restoringUserId, CipherDetails cipherDetails, SutProvider sutProvider) { - cipher.DeletedDate = null; + cipherDetails.DeletedDate = null; - await sutProvider.Sut.RestoreAsync(cipher, restoringUserId, true); + await sutProvider.Sut.RestoreAsync(cipherDetails, restoringUserId, true); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UpsertAsync(default); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCipherEventAsync(default, default); @@ -718,13 +719,13 @@ public class CipherServiceTests [Theory] [BitAutoData] public async Task RestoreAsync_WithPersonalCipherBelongingToDifferentUser_ThrowsBadRequestException( - Guid restoringUserId, Cipher cipher, SutProvider sutProvider) + Guid restoringUserId, CipherDetails cipherDetails, SutProvider sutProvider) { - cipher.UserId = Guid.NewGuid(); - cipher.OrganizationId = null; + cipherDetails.UserId = Guid.NewGuid(); + cipherDetails.OrganizationId = null; var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreAsync(cipher, restoringUserId)); + () => sutProvider.Sut.RestoreAsync(cipherDetails, restoringUserId)); Assert.Contains("do not have permissions", exception.Message); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UpsertAsync(default); @@ -736,14 +737,14 @@ public class CipherServiceTests [OrganizationCipherCustomize] [BitAutoData] public async Task RestoreAsync_WithOrgCipherLackingEditPermission_ThrowsBadRequestException( - Guid restoringUserId, Cipher cipher, SutProvider sutProvider) + Guid restoringUserId, CipherDetails cipherDetails, SutProvider sutProvider) { sutProvider.GetDependency() - .GetCanEditByIdAsync(restoringUserId, cipher.Id) + .GetCanEditByIdAsync(restoringUserId, cipherDetails.Id) .Returns(false); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreAsync(cipher, restoringUserId)); + () => sutProvider.Sut.RestoreAsync(cipherDetails, restoringUserId)); Assert.Contains("do not have permissions", exception.Message); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UpsertAsync(default); @@ -753,7 +754,7 @@ public class CipherServiceTests [Theory] [BitAutoData] - public async Task RestoreAsync_WithCipherDetailsType_RestoresCipherDetails( + public async Task RestoreAsync_WithEditPermission_RestoresCipherDetails( Guid restoringUserId, CipherDetails cipherDetails, SutProvider sutProvider) { sutProvider.GetDependency() @@ -773,6 +774,91 @@ public class CipherServiceTests await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipherDetails, null); } + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task RestoreAsync_WithOrgAdminOverride_RestoresCipher( + Guid restoringUserId, CipherDetails cipherDetails, SutProvider sutProvider) + { + cipherDetails.DeletedDate = DateTime.UtcNow; + + await sutProvider.Sut.RestoreAsync(cipherDetails, restoringUserId, true); + + Assert.Null(cipherDetails.DeletedDate); + Assert.NotEqual(DateTime.UtcNow, cipherDetails.RevisionDate); + await sutProvider.GetDependency().Received(1).UpsertAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_Restored); + await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipherDetails, null); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task RestoreAsync_WithLimitItemDeletionEnabled_WithManagePermission_RestoresCipher( + Guid restoringUserId, CipherDetails cipherDetails, User user, SutProvider sutProvider) + { + cipherDetails.OrganizationId = Guid.NewGuid(); + cipherDetails.DeletedDate = DateTime.UtcNow; + cipherDetails.Edit = false; + cipherDetails.Manage = true; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetUserByIdAsync(restoringUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(cipherDetails.OrganizationId.Value) + .Returns(new OrganizationAbility + { + Id = cipherDetails.OrganizationId.Value, + LimitItemDeletion = true + }); + + await sutProvider.Sut.RestoreAsync(cipherDetails, restoringUserId); + + Assert.Null(cipherDetails.DeletedDate); + Assert.NotEqual(DateTime.UtcNow, cipherDetails.RevisionDate); + await sutProvider.GetDependency().Received(1).UpsertAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_Restored); + await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipherDetails, null); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task RestoreAsync_WithLimitItemDeletionEnabled_WithoutManagePermission_ThrowsBadRequestException( + Guid restoringUserId, CipherDetails cipherDetails, User user, SutProvider sutProvider) + { + cipherDetails.OrganizationId = Guid.NewGuid(); + cipherDetails.DeletedDate = DateTime.UtcNow; + cipherDetails.Edit = true; + cipherDetails.Manage = false; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetUserByIdAsync(restoringUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(cipherDetails.OrganizationId.Value) + .Returns(new OrganizationAbility + { + Id = cipherDetails.OrganizationId.Value, + LimitItemDeletion = true + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RestoreAsync(cipherDetails, restoringUserId)); + + Assert.Contains("do not have permissions", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UpsertAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCipherEventAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().PushSyncCipherUpdateAsync(default, default); + } + [Theory] [BitAutoData] public async Task RestoreManyAsync_UpdatesCiphers(ICollection ciphers, @@ -852,6 +938,239 @@ public class CipherServiceTests await AssertNoActionsAsync(sutProvider); } + [Theory] + [BitAutoData] + public async Task RestoreManyAsync_WithPersonalCipherBelongingToDifferentUser_DoesNotRestoreCiphers( + Guid restoringUserId, List ciphers, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + var differentUserId = Guid.NewGuid(); + + foreach (var cipher in ciphers) + { + cipher.UserId = differentUserId; + cipher.OrganizationId = null; + cipher.DeletedDate = DateTime.UtcNow; + } + + sutProvider.GetDependency() + .GetManyByUserIdAsync(restoringUserId) + .Returns(new List()); + + var result = await sutProvider.Sut.RestoreManyAsync(cipherIds, restoringUserId); + + Assert.Empty(result); + await sutProvider.GetDependency() + .Received(1) + .RestoreAsync(Arg.Is>(ids => !ids.Any()), restoringUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(restoringUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task RestoreManyAsync_WithOrgCipherAndEditPermission_RestoresCiphers( + Guid restoringUserId, List ciphers, Guid organizationId, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + var previousRevisionDate = DateTime.UtcNow; + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + cipher.Edit = true; + cipher.DeletedDate = DateTime.UtcNow; + cipher.RevisionDate = previousRevisionDate; + } + + sutProvider.GetDependency() + .GetManyByUserIdAsync(restoringUserId) + .Returns(ciphers); + + var revisionDate = previousRevisionDate + TimeSpan.FromMinutes(1); + sutProvider.GetDependency() + .RestoreAsync(Arg.Any>(), restoringUserId) + .Returns(revisionDate); + + var result = await sutProvider.Sut.RestoreManyAsync(cipherIds, restoringUserId); + + Assert.Equal(ciphers.Count, result.Count); + foreach (var cipher in result) + { + Assert.Null(cipher.DeletedDate); + Assert.Equal(revisionDate, cipher.RevisionDate); + } + + await sutProvider.GetDependency() + .Received(1) + .RestoreAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && + ids.All(id => cipherIds.Contains(id))), restoringUserId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(restoringUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task RestoreManyAsync_WithOrgCipherLackingEditPermission_DoesNotRestoreCiphers( + Guid restoringUserId, List ciphers, Guid organizationId, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + var cipherDetailsList = ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organizationId, + Edit = false, + DeletedDate = DateTime.UtcNow + }).ToList(); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(restoringUserId) + .Returns(cipherDetailsList); + + var result = await sutProvider.Sut.RestoreManyAsync(cipherIds, restoringUserId); + + Assert.Empty(result); + await sutProvider.GetDependency() + .Received(1) + .RestoreAsync(Arg.Is>(ids => !ids.Any()), restoringUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(restoringUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task RestoreManyAsync_WithLimitItemDeletionEnabled_WithManagePermission_RestoresCiphers( + Guid restoringUserId, List ciphers, User user, SutProvider sutProvider) + { + var organizationId = Guid.NewGuid(); + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + var previousRevisionDate = DateTime.UtcNow; + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + cipher.Edit = false; + cipher.Manage = true; + cipher.DeletedDate = DateTime.UtcNow; + cipher.RevisionDate = previousRevisionDate; + } + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetManyByUserIdAsync(restoringUserId) + .Returns(ciphers); + sutProvider.GetDependency() + .GetUserByIdAsync(restoringUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { + organizationId, new OrganizationAbility + { + Id = organizationId, + LimitItemDeletion = true + } + } + }); + + var revisionDate = previousRevisionDate + TimeSpan.FromMinutes(1); + sutProvider.GetDependency() + .RestoreAsync(Arg.Any>(), restoringUserId) + .Returns(revisionDate); + + var result = await sutProvider.Sut.RestoreManyAsync(cipherIds, restoringUserId); + + Assert.Equal(ciphers.Count, result.Count); + foreach (var cipher in result) + { + Assert.Null(cipher.DeletedDate); + Assert.Equal(revisionDate, cipher.RevisionDate); + } + + await sutProvider.GetDependency() + .Received(1) + .RestoreAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && + ids.All(id => cipherIds.Contains(id))), restoringUserId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(restoringUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task RestoreManyAsync_WithLimitItemDeletionEnabled_WithoutManagePermission_DoesNotRestoreCiphers( + Guid restoringUserId, List ciphers, User user, SutProvider sutProvider) + { + var organizationId = Guid.NewGuid(); + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + cipher.Edit = true; + cipher.Manage = false; + cipher.DeletedDate = DateTime.UtcNow; + } + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetManyByUserIdAsync(restoringUserId) + .Returns(ciphers); + sutProvider.GetDependency() + .GetUserByIdAsync(restoringUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { + organizationId, new OrganizationAbility + { + Id = organizationId, + LimitItemDeletion = true + } + } + }); + + var result = await sutProvider.Sut.RestoreManyAsync(cipherIds, restoringUserId); + + Assert.Empty(result); + await sutProvider.GetDependency() + .Received(1) + .RestoreAsync(Arg.Is>(ids => !ids.Any()), restoringUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(restoringUserId); + } + [Theory, BitAutoData] public async Task ShareManyAsync_FreeOrgWithAttachment_Throws(SutProvider sutProvider, IEnumerable ciphers, Guid organizationId, List collectionIds) @@ -1126,47 +1445,47 @@ public class CipherServiceTests [Theory] [BitAutoData] public async Task DeleteAsync_WithPersonalCipherOwner_DeletesCipher( - Guid deletingUserId, Cipher cipher, SutProvider sutProvider) + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { - cipher.UserId = deletingUserId; - cipher.OrganizationId = null; + cipherDetails.UserId = deletingUserId; + cipherDetails.OrganizationId = null; - await sutProvider.Sut.DeleteAsync(cipher, deletingUserId); + await sutProvider.Sut.DeleteAsync(cipherDetails, deletingUserId); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipher); - await sutProvider.GetDependency().Received(1).DeleteAttachmentsForCipherAsync(cipher.Id); - await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipher, EventType.Cipher_Deleted); - await sutProvider.GetDependency().Received(1).PushSyncCipherDeleteAsync(cipher); + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).DeleteAttachmentsForCipherAsync(cipherDetails.Id); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_Deleted); + await sutProvider.GetDependency().Received(1).PushSyncCipherDeleteAsync(cipherDetails); } [Theory] [OrganizationCipherCustomize] [BitAutoData] public async Task DeleteAsync_WithOrgCipherAndEditPermission_DeletesCipher( - Guid deletingUserId, Cipher cipher, SutProvider sutProvider) + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { sutProvider.GetDependency() - .GetCanEditByIdAsync(deletingUserId, cipher.Id) + .GetCanEditByIdAsync(deletingUserId, cipherDetails.Id) .Returns(true); - await sutProvider.Sut.DeleteAsync(cipher, deletingUserId); + await sutProvider.Sut.DeleteAsync(cipherDetails, deletingUserId); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipher); - await sutProvider.GetDependency().Received(1).DeleteAttachmentsForCipherAsync(cipher.Id); - await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipher, EventType.Cipher_Deleted); - await sutProvider.GetDependency().Received(1).PushSyncCipherDeleteAsync(cipher); + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).DeleteAttachmentsForCipherAsync(cipherDetails.Id); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_Deleted); + await sutProvider.GetDependency().Received(1).PushSyncCipherDeleteAsync(cipherDetails); } [Theory] [BitAutoData] public async Task DeleteAsync_WithPersonalCipherBelongingToDifferentUser_ThrowsBadRequestException( - Guid deletingUserId, Cipher cipher, SutProvider sutProvider) + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { - cipher.UserId = Guid.NewGuid(); - cipher.OrganizationId = null; + cipherDetails.UserId = Guid.NewGuid(); + cipherDetails.OrganizationId = null; var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteAsync(cipher, deletingUserId)); + () => sutProvider.Sut.DeleteAsync(cipherDetails, deletingUserId)); Assert.Contains("do not have permissions", exception.Message); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteAsync(default); @@ -1179,14 +1498,14 @@ public class CipherServiceTests [OrganizationCipherCustomize] [BitAutoData] public async Task DeleteAsync_WithOrgCipherLackingEditPermission_ThrowsBadRequestException( - Guid deletingUserId, Cipher cipher, SutProvider sutProvider) + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { sutProvider.GetDependency() - .GetCanEditByIdAsync(deletingUserId, cipher.Id) + .GetCanEditByIdAsync(deletingUserId, cipherDetails.Id) .Returns(false); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteAsync(cipher, deletingUserId)); + () => sutProvider.Sut.DeleteAsync(cipherDetails, deletingUserId)); Assert.Contains("do not have permissions", exception.Message); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteAsync(default); @@ -1196,62 +1515,400 @@ public class CipherServiceTests } [Theory] + [OrganizationCipherCustomize] [BitAutoData] - public async Task SoftDeleteAsync_WithPersonalCipherOwner_SoftDeletesCipher( - Guid deletingUserId, Cipher cipher, SutProvider sutProvider) + public async Task DeleteAsync_WithOrgAdminOverride_DeletesCipher( + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { - cipher.UserId = deletingUserId; - cipher.OrganizationId = null; - cipher.DeletedDate = null; + await sutProvider.Sut.DeleteAsync(cipherDetails, deletingUserId, true); + + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).DeleteAttachmentsForCipherAsync(cipherDetails.Id); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_Deleted); + await sutProvider.GetDependency().Received(1).PushSyncCipherDeleteAsync(cipherDetails); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task DeleteAsync_WithLimitItemDeletionEnabled_WithManagePermission_DeletesCipher( + Guid deletingUserId, CipherDetails cipherDetails, User user, SutProvider sutProvider) + { + cipherDetails.OrganizationId = Guid.NewGuid(); + cipherDetails.Edit = false; + cipherDetails.Manage = true; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(cipherDetails.OrganizationId.Value) + .Returns(new OrganizationAbility + { + Id = cipherDetails.OrganizationId.Value, + LimitItemDeletion = true + }); + + await sutProvider.Sut.DeleteAsync(cipherDetails, deletingUserId); + + await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).DeleteAttachmentsForCipherAsync(cipherDetails.Id); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_Deleted); + await sutProvider.GetDependency().Received(1).PushSyncCipherDeleteAsync(cipherDetails); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task DeleteAsync_WithLimitItemDeletionEnabled_WithoutManagePermission_ThrowsBadRequestException( + Guid deletingUserId, CipherDetails cipherDetails, User user, SutProvider sutProvider) + { + cipherDetails.OrganizationId = Guid.NewGuid(); + cipherDetails.Edit = true; + cipherDetails.Manage = false; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(cipherDetails.OrganizationId.Value) + .Returns(new OrganizationAbility + { + Id = cipherDetails.OrganizationId.Value, + LimitItemDeletion = true + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(cipherDetails, deletingUserId)); + + Assert.Contains("do not have permissions", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteAttachmentsForCipherAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCipherEventAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().PushSyncCipherDeleteAsync(default); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task DeleteManyAsync_WithOrgAdminOverride_DeletesCiphers( + Guid deletingUserId, List ciphers, Guid organizationId, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + } sutProvider.GetDependency() - .GetCanEditByIdAsync(deletingUserId, cipher.Id) + .GetManyByOrganizationIdAsync(organizationId) + .Returns(ciphers); + + await sutProvider.Sut.DeleteManyAsync(cipherIds, deletingUserId, organizationId, true); + + await sutProvider.GetDependency() + .Received(1) + .DeleteByIdsOrganizationIdAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && + ids.All(id => cipherIds.Contains(id))), organizationId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [BitAutoData] + public async Task DeleteManyAsync_WithPersonalCipherOwner_DeletesCiphers( + Guid deletingUserId, List ciphers, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.UserId = deletingUserId; + cipher.OrganizationId = null; + cipher.Edit = true; + } + + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(ciphers); + + await sutProvider.Sut.DeleteManyAsync(cipherIds, deletingUserId); + + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && + ids.All(id => cipherIds.Contains(id))), deletingUserId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [BitAutoData] + public async Task DeleteManyAsync_WithPersonalCipherBelongingToDifferentUser_DoesNotDeleteCiphers( + Guid deletingUserId, List ciphers, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + var differentUserId = Guid.NewGuid(); + + foreach (var cipher in ciphers) + { + cipher.UserId = differentUserId; + cipher.OrganizationId = null; + } + + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(new List()); + + await sutProvider.Sut.DeleteManyAsync(cipherIds, deletingUserId); + + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(Arg.Is>(ids => !ids.Any()), deletingUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task DeleteManyAsync_WithOrgCipherAndEditPermission_DeletesCiphers( + Guid deletingUserId, List ciphers, Guid organizationId, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + cipher.Edit = true; + } + + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(ciphers); + + await sutProvider.Sut.DeleteManyAsync(cipherIds, deletingUserId, organizationId); + + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && ids.All(id => cipherIds.Contains(id))), deletingUserId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task DeleteManyAsync_WithOrgCipherLackingEditPermission_DoesNotDeleteCiphers( + Guid deletingUserId, List ciphers, Guid organizationId, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + var cipherDetailsList = ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organizationId, + Edit = false + }).ToList(); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(cipherDetailsList); + + await sutProvider.Sut.DeleteManyAsync(cipherIds, deletingUserId, organizationId); + + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(Arg.Is>(ids => !ids.Any()), deletingUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task DeleteManyAsync_WithLimitItemDeletionEnabled_WithoutManagePermission_DoesNotDeleteCiphers( + Guid deletingUserId, List ciphers, User user, SutProvider sutProvider) + { + var organizationId = Guid.NewGuid(); + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + cipher.Edit = true; + cipher.Manage = false; + } + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(ciphers); + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { + organizationId, new OrganizationAbility + { + Id = organizationId, + LimitItemDeletion = true + } + } + }); + + await sutProvider.Sut.DeleteManyAsync(cipherIds, deletingUserId, organizationId); + + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(Arg.Is>(ids => !ids.Any()), deletingUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task DeleteManyAsync_WithLimitItemDeletionEnabled_WithManagePermission_DeletesCiphers( + Guid deletingUserId, List ciphers, User user, SutProvider sutProvider) + { + var organizationId = Guid.NewGuid(); + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + cipher.Edit = false; + cipher.Manage = true; + } + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(ciphers); + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { + organizationId, new OrganizationAbility + { + Id = organizationId, + LimitItemDeletion = true + } + } + }); + + await sutProvider.Sut.DeleteManyAsync(cipherIds, deletingUserId, organizationId); + + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && + ids.All(id => cipherIds.Contains(id))), deletingUserId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [BitAutoData] + public async Task SoftDeleteAsync_WithPersonalCipherOwner_SoftDeletesCipher( + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) + { + cipherDetails.UserId = deletingUserId; + cipherDetails.OrganizationId = null; + cipherDetails.DeletedDate = null; + + sutProvider.GetDependency() + .GetCanEditByIdAsync(deletingUserId, cipherDetails.Id) .Returns(true); - await sutProvider.Sut.SoftDeleteAsync(cipher, deletingUserId); + await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId); - Assert.NotNull(cipher.DeletedDate); - Assert.Equal(cipher.RevisionDate, cipher.DeletedDate); - await sutProvider.GetDependency().Received(1).UpsertAsync(cipher); - await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipher, EventType.Cipher_SoftDeleted); - await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipher, null); + Assert.NotNull(cipherDetails.DeletedDate); + Assert.Equal(cipherDetails.RevisionDate, cipherDetails.DeletedDate); + await sutProvider.GetDependency().Received(1).UpsertAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted); + await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipherDetails, null); } [Theory] [OrganizationCipherCustomize] [BitAutoData] public async Task SoftDeleteAsync_WithOrgCipherAndEditPermission_SoftDeletesCipher( - Guid deletingUserId, Cipher cipher, SutProvider sutProvider) + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { - cipher.DeletedDate = null; + cipherDetails.DeletedDate = null; sutProvider.GetDependency() - .GetCanEditByIdAsync(deletingUserId, cipher.Id) + .GetCanEditByIdAsync(deletingUserId, cipherDetails.Id) .Returns(true); - await sutProvider.Sut.SoftDeleteAsync(cipher, deletingUserId); + await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId); - Assert.NotNull(cipher.DeletedDate); - Assert.Equal(cipher.DeletedDate, cipher.RevisionDate); - await sutProvider.GetDependency().Received(1).UpsertAsync(cipher); - await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipher, EventType.Cipher_SoftDeleted); - await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipher, null); + Assert.NotNull(cipherDetails.DeletedDate); + Assert.Equal(cipherDetails.RevisionDate, cipherDetails.DeletedDate); + await sutProvider.GetDependency().Received(1).UpsertAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted); + await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipherDetails, null); } [Theory] [BitAutoData] public async Task SoftDeleteAsync_WithPersonalCipherBelongingToDifferentUser_ThrowsBadRequestException( - Guid deletingUserId, Cipher cipher, SutProvider sutProvider) + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { - cipher.UserId = Guid.NewGuid(); - cipher.OrganizationId = null; + cipherDetails.UserId = Guid.NewGuid(); + cipherDetails.OrganizationId = null; sutProvider.GetDependency() - .GetCanEditByIdAsync(deletingUserId, cipher.Id) + .GetCanEditByIdAsync(deletingUserId, cipherDetails.Id) .Returns(false); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SoftDeleteAsync(cipher, deletingUserId)); + () => sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId)); Assert.Contains("do not have permissions", exception.Message); } @@ -1260,51 +1917,395 @@ public class CipherServiceTests [OrganizationCipherCustomize] [BitAutoData] public async Task SoftDeleteAsync_WithOrgCipherLackingEditPermission_ThrowsBadRequestException( - Guid deletingUserId, Cipher cipher, SutProvider sutProvider) + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { sutProvider.GetDependency() - .GetCanEditByIdAsync(deletingUserId, cipher.Id) + .GetCanEditByIdAsync(deletingUserId, cipherDetails.Id) .Returns(false); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SoftDeleteAsync(cipher, deletingUserId)); + () => sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId)); Assert.Contains("do not have permissions", exception.Message); } [Theory] [BitAutoData] - public async Task SoftDeleteAsync_WithCipherDetailsType_SoftDeletesCipherDetails( - Guid deletingUserId, CipherDetails cipher, SutProvider sutProvider) + public async Task SoftDeleteAsync_WithEditPermission_SoftDeletesCipherDetails( + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { - cipher.DeletedDate = null; + cipherDetails.DeletedDate = null; - await sutProvider.Sut.SoftDeleteAsync(cipher, deletingUserId, true); + await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId, true); - Assert.NotNull(cipher.DeletedDate); - Assert.Equal(cipher.DeletedDate, cipher.RevisionDate); - await sutProvider.GetDependency().Received(1).UpsertAsync(cipher); - await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipher, EventType.Cipher_SoftDeleted); - await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipher, null); + Assert.NotNull(cipherDetails.DeletedDate); + Assert.Equal(cipherDetails.RevisionDate, cipherDetails.DeletedDate); + await sutProvider.GetDependency().Received(1).UpsertAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted); + await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipherDetails, null); } [Theory] [BitAutoData] public async Task SoftDeleteAsync_WithAlreadySoftDeletedCipher_SkipsOperation( - Guid deletingUserId, Cipher cipher, SutProvider sutProvider) + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) { sutProvider.GetDependency() - .GetCanEditByIdAsync(deletingUserId, cipher.Id) + .GetCanEditByIdAsync(deletingUserId, cipherDetails.Id) .Returns(true); - cipher.DeletedDate = DateTime.UtcNow.AddDays(-1); + cipherDetails.DeletedDate = DateTime.UtcNow.AddDays(-1); - await sutProvider.Sut.SoftDeleteAsync(cipher, deletingUserId); + await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId); await sutProvider.GetDependency().DidNotReceive().UpsertAsync(Arg.Any()); await sutProvider.GetDependency().DidNotReceive().LogCipherEventAsync(Arg.Any(), Arg.Any()); await sutProvider.GetDependency().DidNotReceive().PushSyncCipherUpdateAsync(Arg.Any(), Arg.Any>()); } + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task SoftDeleteAsync_WithOrgAdminOverride_SoftDeletesCipher( + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) + { + cipherDetails.DeletedDate = null; + + await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId, true); + + await sutProvider.GetDependency().Received(1).UpsertAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted); + await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipherDetails, null); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task SoftDeleteAsync_WithLimitItemDeletionEnabled_WithManagePermission_SoftDeletesCipher( + Guid deletingUserId, CipherDetails cipherDetails, User user, SutProvider sutProvider) + { + cipherDetails.OrganizationId = Guid.NewGuid(); + cipherDetails.DeletedDate = null; + cipherDetails.Edit = false; + cipherDetails.Manage = true; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(cipherDetails.OrganizationId.Value) + .Returns(new OrganizationAbility + { + Id = cipherDetails.OrganizationId.Value, + LimitItemDeletion = true + }); + + await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId); + + Assert.NotNull(cipherDetails.DeletedDate); + Assert.Equal(cipherDetails.RevisionDate, cipherDetails.DeletedDate); + await sutProvider.GetDependency().Received(1).UpsertAsync(cipherDetails); + await sutProvider.GetDependency().Received(1).LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted); + await sutProvider.GetDependency().Received(1).PushSyncCipherUpdateAsync(cipherDetails, null); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task SoftDeleteAsync_WithLimitItemDeletionEnabled_WithoutManagePermission_ThrowsBadRequestException( + Guid deletingUserId, CipherDetails cipherDetails, User user, SutProvider sutProvider) + { + cipherDetails.OrganizationId = Guid.NewGuid(); + cipherDetails.DeletedDate = null; + cipherDetails.Edit = true; + cipherDetails.Manage = false; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(cipherDetails.OrganizationId.Value) + .Returns(new OrganizationAbility + { + Id = cipherDetails.OrganizationId.Value, + LimitItemDeletion = true + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId)); + + Assert.Contains("do not have permissions", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UpsertAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCipherEventAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().PushSyncCipherUpdateAsync(default, default); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task SoftDeleteManyAsync_WithOrgAdminOverride_SoftDeletesCiphers( + Guid deletingUserId, List ciphers, Guid organizationId, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + } + + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(organizationId) + .Returns(ciphers); + + await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, organizationId, true); + + await sutProvider.GetDependency() + .Received(1) + .SoftDeleteByIdsOrganizationIdAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && + ids.All(id => cipherIds.Contains(id))), organizationId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [BitAutoData] + public async Task SoftDeleteManyAsync_WithPersonalCipherOwner_SoftDeletesCiphers( + Guid deletingUserId, List ciphers, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.UserId = deletingUserId; + cipher.OrganizationId = null; + cipher.Edit = true; + cipher.DeletedDate = null; + } + + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(ciphers); + + await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, null, false); + + await sutProvider.GetDependency() + .Received(1) + .SoftDeleteAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && + ids.All(id => cipherIds.Contains(id))), deletingUserId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [BitAutoData] + public async Task SoftDeleteManyAsync_WithPersonalCipherBelongingToDifferentUser_DoesNotDeleteCiphers( + Guid deletingUserId, List ciphers, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + var differentUserId = Guid.NewGuid(); + + foreach (var cipher in ciphers) + { + cipher.UserId = differentUserId; + cipher.OrganizationId = null; + } + + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(new List()); + + await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, null, false); + + await sutProvider.GetDependency() + .Received(1) + .SoftDeleteAsync(Arg.Is>(ids => !ids.Any()), deletingUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task SoftDeleteManyAsync_WithOrgCipherAndEditPermission_SoftDeletesCiphers( + Guid deletingUserId, List ciphers, Guid organizationId, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + cipher.Edit = true; + cipher.DeletedDate = null; + } + + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(ciphers); + + await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, organizationId, false); + + await sutProvider.GetDependency() + .Received(1) + .SoftDeleteAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && ids.All(id => cipherIds.Contains(id))), deletingUserId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task SoftDeleteManyAsync_WithOrgCipherLackingEditPermission_DoesNotDeleteCiphers( + Guid deletingUserId, List ciphers, Guid organizationId, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + var cipherDetailsList = ciphers.Select(c => new CipherDetails + { + Id = c.Id, + OrganizationId = organizationId, + Edit = false + }).ToList(); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(cipherDetailsList); + + await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, organizationId, false); + + await sutProvider.GetDependency() + .Received(1) + .SoftDeleteAsync(Arg.Is>(ids => !ids.Any()), deletingUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task SoftDeleteManyAsync_WithLimitItemDeletionEnabled_WithoutManagePermission_DoesNotDeleteCiphers( + Guid deletingUserId, List ciphers, User user, SutProvider sutProvider) + { + var organizationId = Guid.NewGuid(); + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + cipher.Edit = true; + cipher.Manage = false; + } + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(ciphers); + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { + organizationId, new OrganizationAbility + { + Id = organizationId, + LimitItemDeletion = true + } + } + }); + + await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, organizationId, false); + + await sutProvider.GetDependency() + .Received(1) + .SoftDeleteAsync(Arg.Is>(ids => !ids.Any()), deletingUserId); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + + [Theory] + [OrganizationCipherCustomize] + [BitAutoData] + public async Task SoftDeleteManyAsync_WithLimitItemDeletionEnabled_WithManagePermission_SoftDeletesCiphers( + Guid deletingUserId, List ciphers, User user, SutProvider sutProvider) + { + var organizationId = Guid.NewGuid(); + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organizationId; + cipher.Edit = false; + cipher.Manage = true; + cipher.DeletedDate = null; + } + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.LimitItemDeletion) + .Returns(true); + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(ciphers); + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(user); + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { + organizationId, new OrganizationAbility + { + Id = organizationId, + LimitItemDeletion = true + } + } + }); + + await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, organizationId, false); + + await sutProvider.GetDependency() + .Received(1) + .SoftDeleteAsync(Arg.Is>(ids => ids.Count() == cipherIds.Count() && + ids.All(id => cipherIds.Contains(id))), deletingUserId); + await sutProvider.GetDependency() + .Received(1) + .LogCipherEventsAsync(Arg.Any>>()); + await sutProvider.GetDependency() + .Received(1) + .PushSyncCiphersAsync(deletingUserId); + } + private async Task AssertNoActionsAsync(SutProvider sutProvider) { await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default); From 10ea2cb3eb417077fcff86b5a733ce559e8588f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Wed, 2 Apr 2025 11:47:44 +0100 Subject: [PATCH 02/15] [PM-17473] Refactor AuthRequestService to remove admin notification feature flag (#5549) --- .../Implementations/AuthRequestService.cs | 6 -- .../Auth/Services/AuthRequestServiceTests.cs | 77 +------------------ 2 files changed, 1 insertion(+), 82 deletions(-) diff --git a/src/Core/Auth/Services/Implementations/AuthRequestService.cs b/src/Core/Auth/Services/Implementations/AuthRequestService.cs index c10fa6ce92..42d51a88f5 100644 --- a/src/Core/Auth/Services/Implementations/AuthRequestService.cs +++ b/src/Core/Auth/Services/Implementations/AuthRequestService.cs @@ -287,12 +287,6 @@ public class AuthRequestService : IAuthRequestService private async Task NotifyAdminsOfDeviceApprovalRequestAsync(OrganizationUser organizationUser, User user) { - if (!_featureService.IsEnabled(FeatureFlagKeys.DeviceApprovalRequestAdminNotifications)) - { - _logger.LogWarning("Skipped sending device approval notification to admins - feature flag disabled"); - return; - } - var adminEmails = await GetAdminAndAccountRecoveryEmailsAsync(organizationUser.OrganizationId); await _mailService.SendDeviceApprovalRequestedNotificationEmailAsync( diff --git a/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs b/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs index 5e99ecf171..edd7a06fa7 100644 --- a/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs +++ b/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs @@ -273,78 +273,7 @@ public class AuthRequestServiceTests /// each of them. /// [Theory, BitAutoData] - public async Task CreateAuthRequestAsync_AdminApproval_CreatesForEachOrganization( - SutProvider sutProvider, - AuthRequestCreateRequestModel createModel, - User user, - OrganizationUser organizationUser1, - OrganizationUser organizationUser2) - { - createModel.Type = AuthRequestType.AdminApproval; - user.Email = createModel.Email; - organizationUser1.UserId = user.Id; - organizationUser2.UserId = user.Id; - - sutProvider.GetDependency() - .GetByEmailAsync(user.Email) - .Returns(user); - - sutProvider.GetDependency() - .DeviceType - .Returns(DeviceType.ChromeExtension); - - sutProvider.GetDependency() - .UserId - .Returns(user.Id); - - sutProvider.GetDependency() - .PasswordlessAuth.KnownDevicesOnly - .Returns(false); - - - sutProvider.GetDependency() - .GetManyByUserAsync(user.Id) - .Returns(new List - { - organizationUser1, - organizationUser2, - }); - - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(c => c.ArgAt(0)); - - var authRequest = await sutProvider.Sut.CreateAuthRequestAsync(createModel); - - Assert.Equal(organizationUser1.OrganizationId, authRequest.OrganizationId); - - await sutProvider.GetDependency() - .Received(1) - .CreateAsync(Arg.Is(o => o.OrganizationId == organizationUser1.OrganizationId)); - - await sutProvider.GetDependency() - .Received(1) - .CreateAsync(Arg.Is(o => o.OrganizationId == organizationUser2.OrganizationId)); - - await sutProvider.GetDependency() - .Received(2) - .CreateAsync(Arg.Any()); - - await sutProvider.GetDependency() - .Received(1) - .LogUserEventAsync(user.Id, EventType.User_RequestedDeviceApproval); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendDeviceApprovalRequestedNotificationEmailAsync( - Arg.Any>(), - Arg.Any(), - Arg.Any(), - Arg.Any()); - } - - [Theory, BitAutoData] - public async Task CreateAuthRequestAsync_AdminApproval_WithAdminNotifications_CreatesForEachOrganization_SendsEmails( + public async Task CreateAuthRequestAsync_AdminApproval_CreatesForEachOrganization_SendsEmails( SutProvider sutProvider, AuthRequestCreateRequestModel createModel, User user, @@ -369,10 +298,6 @@ public class AuthRequestServiceTests ManageResetPassword = true, }); - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.DeviceApprovalRequestAdminNotifications) - .Returns(true); - sutProvider.GetDependency() .GetByEmailAsync(user.Email) .Returns(user); From b309de141d52a38d81e0b329ca41b64cd00df01f Mon Sep 17 00:00:00 2001 From: Jonas Hendrickx Date: Wed, 2 Apr 2025 19:47:48 +0200 Subject: [PATCH 03/15] [PM-19147] Automatic Tax Improvements (#5545) * Pm 19147 2 (#5544) * Pm 19147 2 (#5544) * Unit tests for tax strategies `GetUpdateOptions` * Only allow automatic tax flag to be updated for complete subscription updates such as plan changes, not when upgrading additional storage, seats, etc * unit tests for factory * Fix build * Automatic tax for tax estimation * Fix stub * Fix stub * "customer.tax_ids" isn't expanded in some flows. * Fix SubscriberServiceTests.cs * BusinessUseAutomaticTaxStrategy > SetUpdateOptions tests * Fix ProviderBillingServiceTests.cs --- .../RemoveOrganizationFromProviderCommand.cs | 28 +- .../Billing/ProviderBillingService.cs | 23 +- ...oveOrganizationFromProviderCommandTests.cs | 20 + .../Billing/ProviderBillingServiceTests.cs | 35 +- .../Implementations/UpcomingInvoiceHandler.cs | 24 +- src/Core/Billing/Constants/StripeConstants.cs | 2 + .../Billing/Extensions/CustomerExtensions.cs | 2 +- .../Extensions/ServiceCollectionExtensions.cs | 4 + .../SubscriptionCreateOptionsExtensions.cs | 26 - .../AutomaticTaxFactoryParameters.cs | 30 ++ .../Billing/Services/IAutomaticTaxFactory.cs | 11 + .../Billing/Services/IAutomaticTaxStrategy.cs | 33 ++ .../AutomaticTax/AutomaticTaxFactory.cs | 50 ++ .../BusinessUseAutomaticTaxStrategy.cs | 96 ++++ .../PersonalUseAutomaticTaxStrategy.cs | 64 +++ .../OrganizationBillingService.cs | 33 +- .../PremiumUserBillingService.cs | 22 +- .../Implementations/SubscriberService.cs | 54 +- src/Core/Constants.cs | 3 + .../Implementations/StripePaymentService.cs | 92 +++- .../BusinessUseAutomaticTaxStrategyTests.cs | 492 ++++++++++++++++++ .../PersonalUseAutomaticTaxStrategyTests.cs | 217 ++++++++ .../AutomaticTaxFactoryTests.cs | 105 ++++ .../Services/SubscriberServiceTests.cs | 55 +- .../Billing/Stubs/FakeAutomaticTaxStrategy.cs | 35 ++ 25 files changed, 1448 insertions(+), 108 deletions(-) delete mode 100644 src/Core/Billing/Extensions/SubscriptionCreateOptionsExtensions.cs create mode 100644 src/Core/Billing/Services/Contracts/AutomaticTaxFactoryParameters.cs create mode 100644 src/Core/Billing/Services/IAutomaticTaxFactory.cs create mode 100644 src/Core/Billing/Services/IAutomaticTaxStrategy.cs create mode 100644 src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs create mode 100644 src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs create mode 100644 src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs create mode 100644 test/Core.Test/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategyTests.cs create mode 100644 test/Core.Test/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategyTests.cs create mode 100644 test/Core.Test/Billing/Services/Implementations/AutomaticTaxFactoryTests.cs create mode 100644 test/Core.Test/Billing/Stubs/FakeAutomaticTaxStrategy.cs diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index d2acdac079..2c34e57a92 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -1,4 +1,5 @@ -using Bit.Core.AdminConsole.Entities; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.Providers.Interfaces; @@ -7,10 +8,12 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; +using Microsoft.Extensions.DependencyInjection; using Stripe; namespace Bit.Commercial.Core.AdminConsole.Providers; @@ -28,6 +31,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv private readonly ISubscriberService _subscriberService; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; private readonly IPricingClient _pricingClient; + private readonly IAutomaticTaxStrategy _automaticTaxStrategy; public RemoveOrganizationFromProviderCommand( IEventService eventService, @@ -40,7 +44,8 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv IProviderBillingService providerBillingService, ISubscriberService subscriberService, IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, - IPricingClient pricingClient) + IPricingClient pricingClient, + [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) { _eventService = eventService; _mailService = mailService; @@ -53,6 +58,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv _subscriberService = subscriberService; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; _pricingClient = pricingClient; + _automaticTaxStrategy = automaticTaxStrategy; } public async Task RemoveOrganizationFromProvider( @@ -107,10 +113,11 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv organization.IsValidClient() && !string.IsNullOrEmpty(organization.GatewayCustomerId)) { - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + var customer = await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Description = string.Empty, - Email = organization.BillingEmail + Email = organization.BillingEmail, + Expand = ["tax", "tax_ids"] }); var plan = await _pricingClient.GetPlanOrThrow(organization.PlanType); @@ -120,7 +127,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv Customer = organization.GatewayCustomerId, CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, DaysUntilDue = 30, - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, Metadata = new Dictionary { { "organizationId", organization.Id.ToString() } @@ -130,6 +136,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }] }; + if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + _automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + } + else + { + subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + } + var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); organization.GatewaySubscriptionId = subscription.Id; diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index 74cfc1f916..65e41ab586 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -14,6 +14,7 @@ using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -22,6 +23,7 @@ using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using CsvHelper; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; @@ -29,6 +31,7 @@ namespace Bit.Commercial.Core.Billing; public class ProviderBillingService( IEventService eventService, + IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, IOrganizationRepository organizationRepository, @@ -40,7 +43,9 @@ public class ProviderBillingService( IProviderUserRepository providerUserRepository, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - ITaxService taxService) : IProviderBillingService + ITaxService taxService, + [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) + : IProviderBillingService { [RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)] public async Task AddExistingOrganization( @@ -557,7 +562,8 @@ public class ProviderBillingService( { ArgumentNullException.ThrowIfNull(provider); - var customer = await subscriberService.GetCustomerOrThrow(provider); + var customerGetOptions = new CustomerGetOptions { Expand = ["tax", "tax_ids"] }; + var customer = await subscriberService.GetCustomerOrThrow(provider, customerGetOptions); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); @@ -589,10 +595,6 @@ public class ProviderBillingService( var subscriptionCreateOptions = new SubscriptionCreateOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }, CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, Customer = customer.Id, DaysUntilDue = 30, @@ -605,6 +607,15 @@ public class ProviderBillingService( ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations }; + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + } + else + { + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + } + try { var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index 2debd521a5..48eda094e8 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -228,6 +228,26 @@ public class RemoveOrganizationFromProviderCommandTests Id = "subscription_id" }); + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == organization.GatewayCustomerId && + options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + options.DaysUntilDue == 30 && + options.Metadata["organizationId"] == organization.Id.ToString() && + options.OffSession == true && + options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && + options.Items.First().Quantity == organization.Seats) + , Arg.Any())) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index c1da732d60..71a150a546 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -924,11 +924,15 @@ public class ProviderBillingServiceTests { provider.GatewaySubscriptionId = null; - sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer - { - Id = "customer_id", - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } - }); + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) + .Returns(new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }); var providerPlans = new List { @@ -975,11 +979,15 @@ public class ProviderBillingServiceTests { provider.GatewaySubscriptionId = null; - sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer + var customer = new Customer { Id = "customer_id", Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } - }); + }; + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); var providerPlans = new List { @@ -1017,6 +1025,19 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == "customer_id") + , Arg.Is(p => p == customer))) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index d37bf41428..f75cbf8a8b 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -1,8 +1,11 @@ -using Bit.Core.AdminConsole.Repositories; +using Bit.Core; +using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,6 +15,7 @@ using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; public class UpcomingInvoiceHandler( + IFeatureService featureService, ILogger logger, IMailService mailService, IOrganizationRepository organizationRepository, @@ -21,7 +25,8 @@ public class UpcomingInvoiceHandler( IStripeEventService stripeEventService, IStripeEventUtilityService stripeEventUtilityService, IUserRepository userRepository, - IValidateSponsorshipCommand validateSponsorshipCommand) + IValidateSponsorshipCommand validateSponsorshipCommand, + IAutomaticTaxFactory automaticTaxFactory) : IUpcomingInvoiceHandler { public async Task HandleAsync(Event parsedEvent) @@ -136,6 +141,21 @@ public class UpcomingInvoiceHandler( private async Task TryEnableAutomaticTaxAsync(Subscription subscription) { + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); + var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); + + if (updateOptions == null) + { + return; + } + + await stripeFacade.UpdateSubscription(subscription.Id, updateOptions); + return; + } + if (subscription.AutomaticTax.Enabled || !subscription.Customer.HasBillingLocation() || await IsNonTaxableNonUSBusinessUseSubscription(subscription)) diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index 080416e2bb..326023e34c 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -47,6 +47,8 @@ public static class StripeConstants public static class MetadataKeys { public const string OrganizationId = "organizationId"; + public const string ProviderId = "providerId"; + public const string UserId = "userId"; } public static class PaymentBehavior diff --git a/src/Core/Billing/Extensions/CustomerExtensions.cs b/src/Core/Billing/Extensions/CustomerExtensions.cs index 1ab595342e..8f15f61a7f 100644 --- a/src/Core/Billing/Extensions/CustomerExtensions.cs +++ b/src/Core/Billing/Extensions/CustomerExtensions.cs @@ -21,7 +21,7 @@ public static class CustomerExtensions /// /// public static bool HasTaxLocationVerified(this Customer customer) => - customer?.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; + customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation; public static decimal GetBillingBalance(this Customer customer) { diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 26815d7df0..17285e0676 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -4,6 +4,7 @@ using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; namespace Bit.Core.Billing.Extensions; @@ -18,6 +19,9 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddKeyedTransient(AutomaticTaxFactory.PersonalUse); + services.AddKeyedTransient(AutomaticTaxFactory.BusinessUse); + services.AddTransient(); services.AddLicenseServices(); services.AddPricingClient(); } diff --git a/src/Core/Billing/Extensions/SubscriptionCreateOptionsExtensions.cs b/src/Core/Billing/Extensions/SubscriptionCreateOptionsExtensions.cs deleted file mode 100644 index d76a0553a3..0000000000 --- a/src/Core/Billing/Extensions/SubscriptionCreateOptionsExtensions.cs +++ /dev/null @@ -1,26 +0,0 @@ -using Stripe; - -namespace Bit.Core.Billing.Extensions; - -public static class SubscriptionCreateOptionsExtensions -{ - /// - /// Attempts to enable automatic tax for given new subscription options. - /// - /// - /// The existing customer. - /// Returns true when successful, false when conditions are not met. - public static bool EnableAutomaticTax(this SubscriptionCreateOptions options, Customer customer) - { - // We might only need to check the automatic tax status. - if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) - { - return false; - } - - options.DefaultTaxRates = []; - options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; - - return true; - } -} diff --git a/src/Core/Billing/Services/Contracts/AutomaticTaxFactoryParameters.cs b/src/Core/Billing/Services/Contracts/AutomaticTaxFactoryParameters.cs new file mode 100644 index 0000000000..19a4f0bdfa --- /dev/null +++ b/src/Core/Billing/Services/Contracts/AutomaticTaxFactoryParameters.cs @@ -0,0 +1,30 @@ +#nullable enable +using Bit.Core.Billing.Enums; +using Bit.Core.Entities; + +namespace Bit.Core.Billing.Services.Contracts; + +public class AutomaticTaxFactoryParameters +{ + public AutomaticTaxFactoryParameters(PlanType planType) + { + PlanType = planType; + } + + public AutomaticTaxFactoryParameters(ISubscriber subscriber, IEnumerable prices) + { + Subscriber = subscriber; + Prices = prices; + } + + public AutomaticTaxFactoryParameters(IEnumerable prices) + { + Prices = prices; + } + + public ISubscriber? Subscriber { get; init; } + + public PlanType? PlanType { get; init; } + + public IEnumerable? Prices { get; init; } +} diff --git a/src/Core/Billing/Services/IAutomaticTaxFactory.cs b/src/Core/Billing/Services/IAutomaticTaxFactory.cs new file mode 100644 index 0000000000..c52a8f2671 --- /dev/null +++ b/src/Core/Billing/Services/IAutomaticTaxFactory.cs @@ -0,0 +1,11 @@ +using Bit.Core.Billing.Services.Contracts; + +namespace Bit.Core.Billing.Services; + +/// +/// Responsible for defining the correct automatic tax strategy for either personal use of business use. +/// +public interface IAutomaticTaxFactory +{ + Task CreateAsync(AutomaticTaxFactoryParameters parameters); +} diff --git a/src/Core/Billing/Services/IAutomaticTaxStrategy.cs b/src/Core/Billing/Services/IAutomaticTaxStrategy.cs new file mode 100644 index 0000000000..292f2d0939 --- /dev/null +++ b/src/Core/Billing/Services/IAutomaticTaxStrategy.cs @@ -0,0 +1,33 @@ +#nullable enable +using Stripe; + +namespace Bit.Core.Billing.Services; + +public interface IAutomaticTaxStrategy +{ + /// + /// + /// + /// + /// + /// Returns if changes are to be applied to the subscription, returns null + /// otherwise. + /// + SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription); + + /// + /// Modifies an existing object with the automatic tax flag set correctly. + /// + /// + /// + void SetCreateOptions(SubscriptionCreateOptions options, Customer customer); + + /// + /// Modifies an existing object with the automatic tax flag set correctly. + /// + /// + /// + void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription); + + void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options); +} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs new file mode 100644 index 0000000000..133cd2c7a7 --- /dev/null +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs @@ -0,0 +1,50 @@ +#nullable enable +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Entities; +using Bit.Core.Services; + +namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; + +public class AutomaticTaxFactory( + IFeatureService featureService, + IPricingClient pricingClient) : IAutomaticTaxFactory +{ + public const string BusinessUse = "business-use"; + public const string PersonalUse = "personal-use"; + + private readonly Lazy>> _personalUsePlansTask = new(async () => + { + var plans = await Task.WhenAll( + pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), + pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)); + + return plans.Select(plan => plan.PasswordManager.StripePlanId); + }); + + public async Task CreateAsync(AutomaticTaxFactoryParameters parameters) + { + if (parameters.Subscriber is User) + { + return new PersonalUseAutomaticTaxStrategy(featureService); + } + + if (parameters.PlanType.HasValue) + { + var plan = await pricingClient.GetPlanOrThrow(parameters.PlanType.Value); + return plan.CanBeUsedByBusiness + ? new BusinessUseAutomaticTaxStrategy(featureService) + : new PersonalUseAutomaticTaxStrategy(featureService); + } + + var personalUsePlans = await _personalUsePlansTask.Value; + + if (parameters.Prices != null && parameters.Prices.Any(x => personalUsePlans.Any(y => y == x))) + { + return new PersonalUseAutomaticTaxStrategy(featureService); + } + + return new BusinessUseAutomaticTaxStrategy(featureService); + } +} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs new file mode 100644 index 0000000000..40eb6e4540 --- /dev/null +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs @@ -0,0 +1,96 @@ +#nullable enable +using Bit.Core.Billing.Extensions; +using Bit.Core.Services; +using Stripe; + +namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; + +public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy +{ + public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + { + return null; + } + + var shouldBeEnabled = ShouldBeEnabled(subscription.Customer); + if (subscription.AutomaticTax.Enabled == shouldBeEnabled) + { + return null; + } + + var options = new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = shouldBeEnabled + }, + DefaultTaxRates = [] + }; + + return options; + } + + public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) + { + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = ShouldBeEnabled(customer) + }; + } + + public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + { + return; + } + + var shouldBeEnabled = ShouldBeEnabled(subscription.Customer); + + if (subscription.AutomaticTax.Enabled == shouldBeEnabled) + { + return; + } + + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = shouldBeEnabled + }; + options.DefaultTaxRates = []; + } + + public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options) + { + options.AutomaticTax ??= new InvoiceAutomaticTaxOptions(); + + if (options.CustomerDetails.Address.Country == "US") + { + options.AutomaticTax.Enabled = true; + return; + } + + options.AutomaticTax.Enabled = options.CustomerDetails.TaxIds != null && options.CustomerDetails.TaxIds.Any(); + } + + private bool ShouldBeEnabled(Customer customer) + { + if (!customer.HasTaxLocationVerified()) + { + return false; + } + + if (customer.Address.Country == "US") + { + return true; + } + + if (customer.TaxIds == null) + { + throw new ArgumentNullException(nameof(customer.TaxIds), "`customer.tax_ids` must be expanded."); + } + + return customer.TaxIds.Any(); + } +} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs new file mode 100644 index 0000000000..15ee1adf8f --- /dev/null +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs @@ -0,0 +1,64 @@ +#nullable enable +using Bit.Core.Billing.Extensions; +using Bit.Core.Services; +using Stripe; + +namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; + +public class PersonalUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy +{ + public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) + { + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = ShouldBeEnabled(customer) + }; + } + + public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + { + return; + } + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = ShouldBeEnabled(subscription.Customer) + }; + options.DefaultTaxRates = []; + } + + public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + { + return null; + } + + if (subscription.AutomaticTax.Enabled == ShouldBeEnabled(subscription.Customer)) + { + return null; + } + + var options = new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = ShouldBeEnabled(subscription.Customer), + }, + DefaultTaxRates = [] + }; + + return options; + } + + public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options) + { + options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; + } + + private static bool ShouldBeEnabled(Customer customer) + { + return customer.HasTaxLocationVerified(); + } +} diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index 8b773f1cef..a4d22cfa3e 100644 --- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -1,9 +1,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -23,6 +25,7 @@ namespace Bit.Core.Billing.Services.Implementations; public class OrganizationBillingService( IBraintreeGateway braintreeGateway, + IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, IOrganizationRepository organizationRepository, @@ -30,7 +33,8 @@ public class OrganizationBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - ITaxService taxService) : IOrganizationBillingService + ITaxService taxService, + IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService { public async Task Finalize(OrganizationSale sale) { @@ -143,7 +147,7 @@ public class OrganizationBillingService( Coupon = customerSetup.Coupon, Description = organization.DisplayBusinessName(), Email = organization.BillingEmail, - Expand = ["tax"], + Expand = ["tax", "tax_ids"], InvoiceSettings = new CustomerInvoiceSettingsOptions { CustomFields = [ @@ -369,21 +373,8 @@ public class OrganizationBillingService( } } - var customerHasTaxInfo = customer is - { - Address: - { - Country: not null and not "", - PostalCode: not null and not "" - } - }; - var subscriptionCreateOptions = new SubscriptionCreateOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = customerHasTaxInfo - }, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, @@ -395,6 +386,18 @@ public class OrganizationBillingService( TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays }; + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType); + var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); + automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + } + else + { + subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions(); + subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation(); + } + return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); } diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index c00a151aa1..6746a8cc98 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -9,6 +10,7 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Braintree; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; using Customer = Stripe.Customer; @@ -20,12 +22,14 @@ using static Utilities; public class PremiumUserBillingService( IBraintreeGateway braintreeGateway, + IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - IUserRepository userRepository) : IPremiumUserBillingService + IUserRepository userRepository, + [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService { public async Task Credit(User user, decimal amount) { @@ -318,10 +322,6 @@ public class PremiumUserBillingService( var subscriptionCreateOptions = new SubscriptionCreateOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported, - }, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, @@ -335,6 +335,18 @@ public class PremiumUserBillingService( OffSession = true }; + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + } + else + { + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported, + }; + } + var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); if (usingPayPal) diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index b2dca19e80..e4b0594433 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -1,6 +1,7 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -20,11 +21,13 @@ namespace Bit.Core.Billing.Services.Implementations; public class SubscriberService( IBraintreeGateway braintreeGateway, + IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ITaxService taxService) : ISubscriberService + ITaxService taxService, + IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -438,7 +441,8 @@ public class SubscriberService( ArgumentNullException.ThrowIfNull(subscriber); ArgumentNullException.ThrowIfNull(tokenizedPaymentSource); - var customer = await GetCustomerOrThrow(subscriber); + var customerGetOptions = new CustomerGetOptions { Expand = ["tax", "tax_ids"] }; + var customer = await GetCustomerOrThrow(subscriber, customerGetOptions); var (type, token) = tokenizedPaymentSource; @@ -597,7 +601,7 @@ public class SubscriberService( Expand = ["subscriptions", "tax", "tax_ids"] }); - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + customer = await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Address = new AddressOptions { @@ -607,7 +611,8 @@ public class SubscriberService( Line2 = taxInformation.Line2, City = taxInformation.City, State = taxInformation.State - } + }, + Expand = ["subscriptions", "tax", "tax_ids"] }); var taxId = customer.TaxIds?.FirstOrDefault(); @@ -661,21 +666,42 @@ public class SubscriberService( } } - if (SubscriberIsEligibleForAutomaticTax(subscriber, customer)) + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, - new SubscriptionUpdateOptions + if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + var subscriptionGetOptions = new SubscriptionGetOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); + Expand = ["customer.tax", "customer.tax_ids"] + }; + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); + var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription); + if (automaticTaxOptions?.AutomaticTax?.Enabled != null) + { + await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions); + } + } } + else + { + if (SubscriberIsEligibleForAutomaticTax(subscriber, customer)) + { + await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } - return; + return; - bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) - => !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && - (localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && - localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; + bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) + => !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && + (localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && + localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; + } } public async Task VerifyBankAccount( diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index b772002dbb..310b917bf7 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -148,6 +148,8 @@ public static class FeatureFlagKeys public const string P15179_AddExistingOrgsFromProviderPortal = "pm-15179-add-existing-orgs-from-provider-portal"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method"; + public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements"; + public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; /* Key Management Team */ public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; @@ -169,6 +171,7 @@ public static class FeatureFlagKeys public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication"; public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync"; public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias"; + public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias"; /* Platform Team */ diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index ca377407f4..cdcd14ca90 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -9,6 +9,8 @@ using Bit.Core.Billing.Models.Api.Responses; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -16,6 +18,7 @@ using Bit.Core.Models.BitStripe; using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Settings; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; using PaymentMethod = Stripe.PaymentMethod; @@ -36,6 +39,8 @@ public class StripePaymentService : IPaymentService private readonly ITaxService _taxService; private readonly ISubscriberService _subscriberService; private readonly IPricingClient _pricingClient; + private readonly IAutomaticTaxFactory _automaticTaxFactory; + private readonly IAutomaticTaxStrategy _personalUseTaxStrategy; public StripePaymentService( ITransactionRepository transactionRepository, @@ -46,7 +51,9 @@ public class StripePaymentService : IPaymentService IFeatureService featureService, ITaxService taxService, ISubscriberService subscriberService, - IPricingClient pricingClient) + IPricingClient pricingClient, + IAutomaticTaxFactory automaticTaxFactory, + [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy) { _transactionRepository = transactionRepository; _logger = logger; @@ -57,6 +64,8 @@ public class StripePaymentService : IPaymentService _taxService = taxService; _subscriberService = subscriberService; _pricingClient = pricingClient; + _automaticTaxFactory = automaticTaxFactory; + _personalUseTaxStrategy = personalUseTaxStrategy; } private async Task ChangeOrganizationSponsorship( @@ -91,9 +100,7 @@ public class StripePaymentService : IPaymentService SubscriptionUpdate subscriptionUpdate, bool invoiceNow = false) { // remember, when in doubt, throw - var subGetOptions = new SubscriptionGetOptions(); - // subGetOptions.AddExpand("customer"); - subGetOptions.AddExpand("customer.tax"); + var subGetOptions = new SubscriptionGetOptions { Expand = ["customer.tax", "customer.tax_ids"] }; var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions); if (sub == null) { @@ -124,7 +131,19 @@ public class StripePaymentService : IPaymentService new SubscriptionPendingInvoiceItemIntervalOptions { Interval = "month" }; } - subUpdateOptions.EnableAutomaticTax(sub.Customer, sub); + if (subscriptionUpdate is CompleteSubscriptionUpdate) + { + if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, updatedItemOptions.Select(x => x.Plan ?? x.Price)); + var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); + automaticTaxStrategy.SetUpdateOptions(subUpdateOptions, sub); + } + else + { + subUpdateOptions.EnableAutomaticTax(sub.Customer, sub); + } + } if (!subscriptionUpdate.UpdateNeeded(sub)) { @@ -811,21 +830,46 @@ public class StripePaymentService : IPaymentService }); } - if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) && - customer.Subscriptions.Any(sub => - sub.Id == subscriber.GatewaySubscriptionId && - !sub.AutomaticTax.Enabled) && - customer.HasTaxLocationVerified()) + if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) { - var subscriptionUpdateOptions = new SubscriptionUpdateOptions + if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, - DefaultTaxRates = [] - }; + var subscriptionGetOptions = new SubscriptionGetOptions + { + Expand = ["customer.tax", "customer.tax_ids"] + }; + var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); - _ = await _stripeAdapter.SubscriptionUpdateAsync( - subscriber.GatewaySubscriptionId, - subscriptionUpdateOptions); + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); + var subscriptionUpdateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); + + if (subscriptionUpdateOptions != null) + { + _ = await _stripeAdapter.SubscriptionUpdateAsync( + subscriber.GatewaySubscriptionId, + subscriptionUpdateOptions); + } + } + } + else + { + if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) && + customer.Subscriptions.Any(sub => + sub.Id == subscriber.GatewaySubscriptionId && + !sub.AutomaticTax.Enabled) && + customer.HasTaxLocationVerified()) + { + var subscriptionUpdateOptions = new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, + DefaultTaxRates = [] + }; + + _ = await _stripeAdapter.SubscriptionUpdateAsync( + subscriber.GatewaySubscriptionId, + subscriptionUpdateOptions); + } } } catch @@ -1214,6 +1258,8 @@ public class StripePaymentService : IPaymentService } } + _personalUseTaxStrategy.SetInvoiceCreatePreviewOptions(options); + try { var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); @@ -1256,10 +1302,6 @@ public class StripePaymentService : IPaymentService var options = new InvoiceCreatePreviewOptions { - AutomaticTax = new InvoiceAutomaticTaxOptions - { - Enabled = true, - }, Currency = "usd", SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { @@ -1347,9 +1389,11 @@ public class StripePaymentService : IPaymentService ]; } + Customer gatewayCustomer = null; + if (!string.IsNullOrWhiteSpace(gatewayCustomerId)) { - var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); if (gatewayCustomer.Discount != null) { @@ -1367,6 +1411,10 @@ public class StripePaymentService : IPaymentService } } + var automaticTaxFactoryParameters = new AutomaticTaxFactoryParameters(parameters.PasswordManager.Plan); + var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxFactoryParameters); + automaticTaxStrategy.SetInvoiceCreatePreviewOptions(options); + try { var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); diff --git a/test/Core.Test/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategyTests.cs b/test/Core.Test/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategyTests.cs new file mode 100644 index 0000000000..dc40656275 --- /dev/null +++ b/test/Core.Test/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategyTests.cs @@ -0,0 +1,492 @@ +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Services.Implementations.AutomaticTax; + +[SutProviderCustomize] +public class BusinessUseAutomaticTaxStrategyTests +{ + [Theory] + [BitAutoData] + public void GetUpdateOptions_ReturnsNull_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled( + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(false); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.Null(actual); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_ReturnsNull_WhenSubscriptionDoesNotNeedUpdating( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.Null(actual); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.False(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForAmericanCustomers( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List + { + new() + { + Country = "ES", + Type = "eu_vat", + Value = "ESZ8880999Z" + } + } + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_ThrowsArgumentNullException_WhenTaxIdsIsNull( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = null + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + Assert.Throws(() => sutProvider.Sut.GetUpdateOptions(subscription)); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds( + SutProvider sutProvider) + { + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List() + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.False(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsNothing_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + Customer = new Customer + { + Address = new() + { + Country = "US" + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(false); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.Null(options.AutomaticTax); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsNothing_WhenSubscriptionDoesNotNeedUpdating( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.Null(options.AutomaticTax); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.False(options.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForAmericanCustomers( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.True(options.AutomaticTax!.Enabled); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List + { + new() + { + Country = "ES", + Type = "eu_vat", + Value = "ESZ8880999Z" + } + } + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.True(options.AutomaticTax!.Enabled); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_ThrowsArgumentNullException_WhenTaxIdsIsNull( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = null + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + Assert.Throws(() => sutProvider.Sut.SetUpdateOptions(options, subscription)); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List() + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.False(options.AutomaticTax!.Enabled); + } +} diff --git a/test/Core.Test/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategyTests.cs b/test/Core.Test/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategyTests.cs new file mode 100644 index 0000000000..2d50c9f75a --- /dev/null +++ b/test/Core.Test/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategyTests.cs @@ -0,0 +1,217 @@ +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Services.Implementations.AutomaticTax; + +[SutProviderCustomize] +public class PersonalUseAutomaticTaxStrategyTests +{ + [Theory] + [BitAutoData] + public void GetUpdateOptions_ReturnsNull_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled( + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(false); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.Null(actual); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_ReturnsNull_WhenSubscriptionDoesNotNeedUpdating( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.Null(actual); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.False(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData("CA")] + [BitAutoData("ES")] + [BitAutoData("US")] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForAllCountries( + string country, SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = country + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData("CA")] + [BitAutoData("ES")] + [BitAutoData("US")] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds( + string country, SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = country, + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List + { + new() + { + Country = "ES", + Type = "eu_vat", + Value = "ESZ8880999Z" + } + } + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData("CA")] + [BitAutoData("ES")] + [BitAutoData("US")] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds( + string country, SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = country + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List() + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } +} diff --git a/test/Core.Test/Billing/Services/Implementations/AutomaticTaxFactoryTests.cs b/test/Core.Test/Billing/Services/Implementations/AutomaticTaxFactoryTests.cs new file mode 100644 index 0000000000..7d5c9c3a26 --- /dev/null +++ b/test/Core.Test/Billing/Services/Implementations/AutomaticTaxFactoryTests.cs @@ -0,0 +1,105 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; +using Bit.Core.Entities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Billing.Services.Implementations; + +[SutProviderCustomize] +public class AutomaticTaxFactoryTests +{ + [BitAutoData] + [Theory] + public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenSubscriberIsUser(SutProvider sut) + { + var parameters = new AutomaticTaxFactoryParameters(new User(), []); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + [BitAutoData] + [Theory] + public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenSubscriberIsOrganizationWithFamiliesAnnuallyPrice( + SutProvider sut) + { + var familiesPlan = new FamiliesPlan(); + var parameters = new AutomaticTaxFactoryParameters(new Organization(), [familiesPlan.PasswordManager.StripePlanId]); + + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) + .Returns(new FamiliesPlan()); + + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually2019)) + .Returns(new Families2019Plan()); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + [Theory] + [BitAutoData] + public async Task CreateAsync_ReturnsBusinessUseStrategy_WhenSubscriberIsOrganizationWithBusinessUsePrice( + EnterpriseAnnually plan, + SutProvider sut) + { + var parameters = new AutomaticTaxFactoryParameters(new Organization(), [plan.PasswordManager.StripePlanId]); + + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) + .Returns(new FamiliesPlan()); + + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually2019)) + .Returns(new Families2019Plan()); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + [Theory] + [BitAutoData] + public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenPlanIsMeantForPersonalUse(SutProvider sut) + { + var parameters = new AutomaticTaxFactoryParameters(PlanType.FamiliesAnnually); + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == parameters.PlanType.Value)) + .Returns(new FamiliesPlan()); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + [Theory] + [BitAutoData] + public async Task CreateAsync_ReturnsBusinessUseStrategy_WhenPlanIsMeantForBusinessUse(SutProvider sut) + { + var parameters = new AutomaticTaxFactoryParameters(PlanType.EnterpriseAnnually); + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == parameters.PlanType.Value)) + .Returns(new EnterprisePlan(true)); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + public record EnterpriseAnnually : EnterprisePlan + { + public EnterpriseAnnually() : base(true) + { + } + } +} diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 5b7a2cc8bd..9e4be78787 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -3,10 +3,13 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Enums; using Bit.Core.Services; using Bit.Core.Settings; +using Bit.Core.Test.Billing.Stubs; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Braintree; @@ -1167,7 +1170,9 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId) + stripeAdapter.CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1213,7 +1218,10 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId) + stripeAdapter.CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")) + ) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1321,7 +1329,9 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1373,7 +1383,9 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1482,7 +1494,9 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer { Id = provider.GatewayCustomerId @@ -1561,6 +1575,37 @@ public class SubscriberServiceTests "Example Town", "NY"); + sutProvider.GetDependency() + .CustomerUpdateAsync( + Arg.Is(p => p == provider.GatewayCustomerId), + Arg.Is(options => + options.Address.Country == "US" && + options.Address.PostalCode == "12345" && + options.Address.Line1 == "123 Example St." && + options.Address.Line2 == null && + options.Address.City == "Example Town" && + options.Address.State == "NY")) + .Returns(new Customer + { + Id = provider.GatewayCustomerId, + Address = new Address + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Example St.", + Line2 = null, + City = "Example Town", + State = "NY" + }, + TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } + }); + + var subscription = new Subscription { Items = new StripeList() }; + sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + .Returns(subscription); + sutProvider.GetDependency().CreateAsync(Arg.Any()) + .Returns(new FakeAutomaticTaxStrategy(true)); + await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( diff --git a/test/Core.Test/Billing/Stubs/FakeAutomaticTaxStrategy.cs b/test/Core.Test/Billing/Stubs/FakeAutomaticTaxStrategy.cs new file mode 100644 index 0000000000..253aead5c7 --- /dev/null +++ b/test/Core.Test/Billing/Stubs/FakeAutomaticTaxStrategy.cs @@ -0,0 +1,35 @@ +using Bit.Core.Billing.Services; +using Stripe; + +namespace Bit.Core.Test.Billing.Stubs; + +/// +/// Whether the subscription options will have automatic tax enabled or not. +/// +public class FakeAutomaticTaxStrategy( + bool isAutomaticTaxEnabled) : IAutomaticTaxStrategy +{ + public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) + { + return new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled } + }; + } + + public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) + { + options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }; + } + + public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) + { + options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }; + } + + public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options) + { + options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }; + + } +} From aef05f5fb67c8635238d4a690b9d075e554e9ca1 Mon Sep 17 00:00:00 2001 From: Jimmy Vo Date: Wed, 2 Apr 2025 15:23:31 -0400 Subject: [PATCH 04/15] [PM-19290] Skip the notification step if no admin emails are available. (#5582) --- .../Implementations/AuthRequestService.cs | 6 ++ .../Auth/Services/AuthRequestServiceTests.cs | 82 +++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/src/Core/Auth/Services/Implementations/AuthRequestService.cs b/src/Core/Auth/Services/Implementations/AuthRequestService.cs index 42d51a88f5..0fd1846d00 100644 --- a/src/Core/Auth/Services/Implementations/AuthRequestService.cs +++ b/src/Core/Auth/Services/Implementations/AuthRequestService.cs @@ -289,6 +289,12 @@ public class AuthRequestService : IAuthRequestService { var adminEmails = await GetAdminAndAccountRecoveryEmailsAsync(organizationUser.OrganizationId); + if (adminEmails.Count == 0) + { + _logger.LogWarning("There are no admin emails to send to."); + return; + } + await _mailService.SendDeviceApprovalRequestedNotificationEmailAsync( adminEmails, organizationUser.OrganizationId, diff --git a/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs b/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs index edd7a06fa7..eec6747c5f 100644 --- a/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs +++ b/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs @@ -17,6 +17,7 @@ using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; +using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; using GlobalSettings = Bit.Core.Settings.GlobalSettings; @@ -395,6 +396,87 @@ public class AuthRequestServiceTests user.Name); } + + [Theory, BitAutoData] + public async Task CreateAuthRequestAsync_AdminApproval_WithAdminNotifications_AndNoAdminEmails_ShouldNotSendNotificationEmails( + SutProvider sutProvider, + AuthRequestCreateRequestModel createModel, + User user, + OrganizationUser organizationUser1) + { + createModel.Type = AuthRequestType.AdminApproval; + user.Email = createModel.Email; + organizationUser1.UserId = user.Id; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DeviceApprovalRequestAdminNotifications) + .Returns(true); + + sutProvider.GetDependency() + .GetByEmailAsync(user.Email) + .Returns(user); + + sutProvider.GetDependency() + .DeviceType + .Returns(DeviceType.ChromeExtension); + + sutProvider.GetDependency() + .UserId + .Returns(user.Id); + + sutProvider.GetDependency() + .PasswordlessAuth.KnownDevicesOnly + .Returns(false); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns(new List + { + organizationUser1, + }); + + sutProvider.GetDependency() + .GetManyByMinimumRoleAsync(organizationUser1.OrganizationId, OrganizationUserType.Admin) + .Returns([]); + + sutProvider.GetDependency() + .GetManyDetailsByRoleAsync(organizationUser1.OrganizationId, OrganizationUserType.Custom) + .Returns([]); + + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(c => c.ArgAt(0)); + + var authRequest = await sutProvider.Sut.CreateAuthRequestAsync(createModel); + + Assert.Equal(organizationUser1.OrganizationId, authRequest.OrganizationId); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(Arg.Is(o => o.OrganizationId == organizationUser1.OrganizationId)); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .LogUserEventAsync(user.Id, EventType.User_RequestedDeviceApproval); + + await sutProvider.GetDependency() + .Received(0) + .SendDeviceApprovalRequestedNotificationEmailAsync( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + var expectedLogMessage = "There are no admin emails to send to."; + sutProvider.GetDependency>() + .Received(1) + .LogWarning(expectedLogMessage); + } + /// /// Story: When an is approved we want to update it in the database so it cannot have /// it's status changed again and we want to push a notification to let the user know of the approval. From 7b2b62e794e1b89703a8d8ca46254bb18b8005a0 Mon Sep 17 00:00:00 2001 From: Nick Krantz <125900171+nick-livefront@users.noreply.github.com> Date: Wed, 2 Apr 2025 15:18:53 -0500 Subject: [PATCH 05/15] [PM-18858] Security Task email plurality (#5588) * use handlebars helper for plurality of text rather than logic within the template * Remove `TaskCountPlural` - unused --- .../Handlebars/Layouts/SecurityTasks.html.hbs | 7 ++---- .../Handlebars/Layouts/SecurityTasks.text.hbs | 4 +--- .../Mail/SecurityTaskNotificationViewModel.cs | 2 -- .../Implementations/HandlebarsMailService.cs | 23 +++++++++++++++++++ 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.html.hbs b/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.html.hbs index 930d39eeee..67537b81a7 100644 --- a/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.html.hbs +++ b/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.html.hbs @@ -6,11 +6,8 @@ -
- {{OrgName}} has identified {{TaskCount}} critical login{{#if TaskCountPlural}}s{{/if}} that require{{#unless - TaskCountPlural}}s{{/unless}} a - password change + + {{OrgName}} has identified {{TaskCount}} critical {{plurality TaskCount "login" "logins"}} that {{plurality TaskCount "requires" "require"}} a password change
diff --git a/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.text.hbs b/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.text.hbs index f9befac46c..009e2b923f 100644 --- a/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.text.hbs +++ b/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.text.hbs @@ -1,7 +1,5 @@ {{#>FullTextLayout}} -{{OrgName}} has identified {{TaskCount}} critical login{{#if TaskCountPlural}}s{{/if}} that require{{#unless -TaskCountPlural}}s{{/unless}} a -password change +{{OrgName}} has identified {{TaskCount}} critical {{plurality TaskCount "login" "logins"}} that {{plurality TaskCount "requires" "require"}} a password change {{>@partial-block}} diff --git a/src/Core/Models/Mail/SecurityTaskNotificationViewModel.cs b/src/Core/Models/Mail/SecurityTaskNotificationViewModel.cs index 9b4ede6e01..d41ca41146 100644 --- a/src/Core/Models/Mail/SecurityTaskNotificationViewModel.cs +++ b/src/Core/Models/Mail/SecurityTaskNotificationViewModel.cs @@ -6,8 +6,6 @@ public class SecurityTaskNotificationViewModel : BaseMailModel public int TaskCount { get; set; } - public bool TaskCountPlural => TaskCount != 1; - public List AdminOwnerEmails { get; set; } public string ReviewPasswordsUrl => $"{WebVaultUrl}/browser-extension-prompt"; diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index 430636f44d..a551342324 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -794,6 +794,29 @@ public class HandlebarsMailService : IMailService writer.WriteSafeString($"{outputMessage}"); }); + + // Returns the singular or plural form of a word based on the provided numeric value. + Handlebars.RegisterHelper("plurality", (writer, context, parameters) => + { + if (parameters.Length != 3) + { + writer.WriteSafeString(string.Empty); + return; + } + + var numeric = parameters[0]; + var singularText = parameters[1].ToString(); + var pluralText = parameters[2].ToString(); + + if (numeric is int number) + { + writer.WriteSafeString(number == 1 ? singularText : pluralText); + } + else + { + writer.WriteSafeString(string.Empty); + } + }); } public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) From d4a3cd00befd4faf612e942b7a6dfa374c69054e Mon Sep 17 00:00:00 2001 From: Shane Melton Date: Wed, 2 Apr 2025 13:44:59 -0700 Subject: [PATCH 06/15] [PM-17563] Add missing TaskId and HubHelper for PendingSecurityTasks (#5591) * [PM-17563] Add case for PushType.PendingSecurityTasks * [PM-17563] Add missing TaskId property to NotificationStatusDetails and NotificationResponseModel * [PM-17563] Add migration script to re-create NotificationStatusDetailsView to include TaskId column * [PM-17563] Select explicit columns for NotificationStatusDetailsView and fix migration script --- .../Response/NotificationResponseModel.cs | 3 +++ .../Models/Data/NotificationStatusDetails.cs | 1 + .../NotificationStatusDetailsViewQuery.cs | 1 + src/Notifications/HubHelpers.cs | 5 ++++ .../Views/NotificationStatusDetailsView.sql | 18 ++++++++++--- .../NotificationsControllerTests.cs | 3 +++ .../NotificationResponseModelTests.cs | 2 ++ ...4-01_00_RecreateNotificationStatusView.sql | 25 +++++++++++++++++++ 8 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 util/Migrator/DbScripts/2025-04-01_00_RecreateNotificationStatusView.sql diff --git a/src/Api/NotificationCenter/Models/Response/NotificationResponseModel.cs b/src/Api/NotificationCenter/Models/Response/NotificationResponseModel.cs index 1ebed87de2..ab882d5557 100644 --- a/src/Api/NotificationCenter/Models/Response/NotificationResponseModel.cs +++ b/src/Api/NotificationCenter/Models/Response/NotificationResponseModel.cs @@ -22,6 +22,7 @@ public class NotificationResponseModel : ResponseModel Title = notificationStatusDetails.Title; Body = notificationStatusDetails.Body; Date = notificationStatusDetails.RevisionDate; + TaskId = notificationStatusDetails.TaskId; ReadDate = notificationStatusDetails.ReadDate; DeletedDate = notificationStatusDetails.DeletedDate; } @@ -40,6 +41,8 @@ public class NotificationResponseModel : ResponseModel public DateTime Date { get; set; } + public Guid? TaskId { get; set; } + public DateTime? ReadDate { get; set; } public DateTime? DeletedDate { get; set; } diff --git a/src/Core/NotificationCenter/Models/Data/NotificationStatusDetails.cs b/src/Core/NotificationCenter/Models/Data/NotificationStatusDetails.cs index d48985e725..5ad8decb94 100644 --- a/src/Core/NotificationCenter/Models/Data/NotificationStatusDetails.cs +++ b/src/Core/NotificationCenter/Models/Data/NotificationStatusDetails.cs @@ -19,6 +19,7 @@ public class NotificationStatusDetails public string? Body { get; set; } public DateTime CreationDate { get; set; } public DateTime RevisionDate { get; set; } + public Guid? TaskId { get; set; } // Notification Status fields public DateTime? ReadDate { get; set; } public DateTime? DeletedDate { get; set; } diff --git a/src/Infrastructure.EntityFramework/NotificationCenter/Repositories/Queries/NotificationStatusDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/NotificationCenter/Repositories/Queries/NotificationStatusDetailsViewQuery.cs index 2f8bade1d3..41f8610101 100644 --- a/src/Infrastructure.EntityFramework/NotificationCenter/Repositories/Queries/NotificationStatusDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/NotificationCenter/Repositories/Queries/NotificationStatusDetailsViewQuery.cs @@ -52,6 +52,7 @@ public class NotificationStatusDetailsViewQuery(Guid userId, ClientType clientTy ClientType = x.n.ClientType, UserId = x.n.UserId, OrganizationId = x.n.OrganizationId, + TaskId = x.n.TaskId, Title = x.n.Title, Body = x.n.Body, CreationDate = x.n.CreationDate, diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 8fa74f7b84..441842da3b 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -135,6 +135,11 @@ public static class HubHelpers } break; + case PushType.PendingSecurityTasks: + var pendingTasksData = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); + await hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) + .SendAsync(_receiveMessageMethod, pendingTasksData, cancellationToken); + break; default: break; } diff --git a/src/Sql/NotificationCenter/dbo/Views/NotificationStatusDetailsView.sql b/src/Sql/NotificationCenter/dbo/Views/NotificationStatusDetailsView.sql index 5264be2009..57298152c7 100644 --- a/src/Sql/NotificationCenter/dbo/Views/NotificationStatusDetailsView.sql +++ b/src/Sql/NotificationCenter/dbo/Views/NotificationStatusDetailsView.sql @@ -1,10 +1,20 @@ CREATE VIEW [dbo].[NotificationStatusDetailsView] AS SELECT - N.*, - NS.UserId AS NotificationStatusUserId, - NS.ReadDate, - NS.DeletedDate + N.[Id], + N.[Priority], + N.[Global], + N.[ClientType], + N.[UserId], + N.[OrganizationId], + N.[Title], + N.[Body], + N.[CreationDate], + N.[RevisionDate], + N.[TaskId], + NS.[UserId] AS [NotificationStatusUserId], + NS.[ReadDate], + NS.[DeletedDate] FROM [dbo].[Notification] AS N LEFT JOIN diff --git a/test/Api.Test/NotificationCenter/Controllers/NotificationsControllerTests.cs b/test/Api.Test/NotificationCenter/Controllers/NotificationsControllerTests.cs index b8b21ef419..094ef2918e 100644 --- a/test/Api.Test/NotificationCenter/Controllers/NotificationsControllerTests.cs +++ b/test/Api.Test/NotificationCenter/Controllers/NotificationsControllerTests.cs @@ -67,6 +67,7 @@ public class NotificationsControllerTests Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date); Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate); Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate); + Assert.Equal(expectedNotificationStatusDetails.TaskId, notificationResponseModel.TaskId); }); Assert.Null(listResponse.ContinuationToken); @@ -116,6 +117,7 @@ public class NotificationsControllerTests Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date); Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate); Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate); + Assert.Equal(expectedNotificationStatusDetails.TaskId, notificationResponseModel.TaskId); }); Assert.Equal("2", listResponse.ContinuationToken); @@ -164,6 +166,7 @@ public class NotificationsControllerTests Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date); Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate); Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate); + Assert.Equal(expectedNotificationStatusDetails.TaskId, notificationResponseModel.TaskId); }); Assert.Null(listResponse.ContinuationToken); diff --git a/test/Api.Test/NotificationCenter/Models/Response/NotificationResponseModelTests.cs b/test/Api.Test/NotificationCenter/Models/Response/NotificationResponseModelTests.cs index f0dfc03fec..171b972575 100644 --- a/test/Api.Test/NotificationCenter/Models/Response/NotificationResponseModelTests.cs +++ b/test/Api.Test/NotificationCenter/Models/Response/NotificationResponseModelTests.cs @@ -26,6 +26,7 @@ public class NotificationResponseModelTests ClientType = ClientType.All, Title = "Test Title", Body = "Test Body", + TaskId = Guid.NewGuid(), RevisionDate = DateTime.UtcNow - TimeSpan.FromMinutes(3), ReadDate = DateTime.UtcNow - TimeSpan.FromMinutes(1), DeletedDate = DateTime.UtcNow, @@ -39,5 +40,6 @@ public class NotificationResponseModelTests Assert.Equal(model.Date, notificationStatusDetails.RevisionDate); Assert.Equal(model.ReadDate, notificationStatusDetails.ReadDate); Assert.Equal(model.DeletedDate, notificationStatusDetails.DeletedDate); + Assert.Equal(model.TaskId, notificationStatusDetails.TaskId); } } diff --git a/util/Migrator/DbScripts/2025-04-01_00_RecreateNotificationStatusView.sql b/util/Migrator/DbScripts/2025-04-01_00_RecreateNotificationStatusView.sql new file mode 100644 index 0000000000..727218f9ab --- /dev/null +++ b/util/Migrator/DbScripts/2025-04-01_00_RecreateNotificationStatusView.sql @@ -0,0 +1,25 @@ +-- Recreate the NotificationStatusView to include the Notification.TaskId column +CREATE OR ALTER VIEW [dbo].[NotificationStatusDetailsView] +AS +SELECT + N.[Id], + N.[Priority], + N.[Global], + N.[ClientType], + N.[UserId], + N.[OrganizationId], + N.[Title], + N.[Body], + N.[CreationDate], + N.[RevisionDate], + N.[TaskId], + NS.[UserId] AS [NotificationStatusUserId], + NS.[ReadDate], + NS.[DeletedDate] +FROM + [dbo].[Notification] AS N + LEFT JOIN + [dbo].[NotificationStatus] as NS +ON + N.[Id] = NS.[NotificationId] +GO From 0069866deaf23bc859973b43f9b0de875c9939ff Mon Sep 17 00:00:00 2001 From: Brandon Treston Date: Wed, 2 Apr 2025 17:07:05 -0400 Subject: [PATCH 07/15] override exempt status to include Invited (#5596) --- .../PolicyRequirements/ResetPasswordPolicyRequirement.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/ResetPasswordPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/ResetPasswordPolicyRequirement.cs index 4feef1b088..b7d0b14f15 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/ResetPasswordPolicyRequirement.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/ResetPasswordPolicyRequirement.cs @@ -34,6 +34,8 @@ public class ResetPasswordPolicyRequirementFactory : BasePolicyRequirementFactor protected override IEnumerable ExemptRoles => []; + protected override IEnumerable ExemptStatuses => [OrganizationUserStatusType.Revoked]; + public override ResetPasswordPolicyRequirement Create(IEnumerable policyDetails) { var result = policyDetails From 8fd48374dc4942fba8c8c0cd9525776a503eb8aa Mon Sep 17 00:00:00 2001 From: Bernd Schoolmann Date: Thu, 3 Apr 2025 11:30:49 +0200 Subject: [PATCH 08/15] [PM-2199] Implement userkey rotation for all TDE devices (#5446) * Implement userkey rotation v2 * Update request models * Cleanup * Update tests * Improve test * Add tests * Fix formatting * Fix test * Remove whitespace * Fix namespace * Enable nullable on models * Fix build * Add tests and enable nullable on masterpasswordunlockdatamodel * Fix test * Remove rollback * Add tests * Make masterpassword hint optional * Update user query * Add EF test * Improve test * Cleanup * Set masterpassword hint * Remove connection close * Add tests for invalid kdf types * Update test/Core.Test/KeyManagement/UserKey/RotateUserAccountKeysCommandTests.cs Co-authored-by: Thomas Avery <43214426+Thomas-Avery@users.noreply.github.com> * Fix formatting * Update src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs Co-authored-by: Thomas Avery <43214426+Thomas-Avery@users.noreply.github.com> * Update src/Api/Auth/Models/Request/Accounts/MasterPasswordUnlockDataModel.cs Co-authored-by: Thomas Avery <43214426+Thomas-Avery@users.noreply.github.com> * Update src/Api/Auth/Models/Request/Accounts/MasterPasswordUnlockDataModel.cs Co-authored-by: Thomas Avery <43214426+Thomas-Avery@users.noreply.github.com> * Update src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs Co-authored-by: Thomas Avery <43214426+Thomas-Avery@users.noreply.github.com> * Fix imports * Fix tests * Add poc for tde rotation * Improve rotation transaction safety * Add validator tests * Clean up validator * Add newline * Add devicekey unlock data to integration test * Fix tests * Fix tests * Remove null check * Remove null check * Fix IsTrusted returning wrong result * Add rollback * Cleanup * Address feedback * Further renames --------- Co-authored-by: Thomas Avery <43214426+Thomas-Avery@users.noreply.github.com> --- src/Api/Controllers/DevicesController.cs | 10 +--- .../AccountsKeyManagementController.cs | 7 ++- .../Models/Requests/UnlockDataRequestModel.cs | 2 + .../Validators/DeviceRotationValidator.cs | 53 +++++++++++++++++++ src/Api/Startup.cs | 5 +- .../Request/DeviceKeysUpdateRequestModel.cs | 8 +++ .../DeviceAuthRequestResponseModel.cs | 3 +- .../Models/Data/RotateUserAccountKeysData.cs | 1 + .../RotateUserAccountkeysCommand.cs | 8 +++ src/Core/Repositories/IDeviceRepository.cs | 2 + .../Repositories/DeviceRepository.cs | 33 ++++++++++++ .../Repositories/DeviceRepository.cs | 27 ++++++++++ .../AccountsKeyManagementControllerTests.cs | 4 ++ .../DeviceRotationValidatorTests.cs | 49 +++++++++++++++++ 14 files changed, 199 insertions(+), 13 deletions(-) create mode 100644 src/Api/KeyManagement/Validators/DeviceRotationValidator.cs create mode 100644 test/Api.Test/KeyManagement/Validators/DeviceRotationValidatorTests.cs diff --git a/src/Api/Controllers/DevicesController.cs b/src/Api/Controllers/DevicesController.cs index 02eb2d36d5..4e21b5e9dc 100644 --- a/src/Api/Controllers/DevicesController.cs +++ b/src/Api/Controllers/DevicesController.cs @@ -1,6 +1,5 @@ using System.ComponentModel.DataAnnotations; using Bit.Api.Auth.Models.Request; -using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Models.Request; using Bit.Api.Models.Response; using Bit.Core.Auth.Models.Api.Request; @@ -125,7 +124,7 @@ public class DevicesController : Controller } [HttpPost("{identifier}/retrieve-keys")] - public async Task GetDeviceKeys(string identifier, [FromBody] SecretVerificationRequestModel model) + public async Task GetDeviceKeys(string identifier) { var user = await _userService.GetUserByPrincipalAsync(User); @@ -134,14 +133,7 @@ public class DevicesController : Controller throw new UnauthorizedAccessException(); } - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - var device = await _deviceRepository.GetByIdentifierAsync(identifier, user.Id); - if (device == null) { throw new NotFoundException(); diff --git a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs index 85e0981f22..0764e2ee28 100644 --- a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs +++ b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs @@ -8,6 +8,7 @@ using Bit.Api.Tools.Models.Request; using Bit.Api.Vault.Models.Request; using Bit.Core; using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Models.Api.Request; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; using Bit.Core.Exceptions; @@ -43,6 +44,7 @@ public class AccountsKeyManagementController : Controller _organizationUserValidator; private readonly IRotationValidator, IEnumerable> _webauthnKeyValidator; + private readonly IRotationValidator, IEnumerable> _deviceValidator; public AccountsKeyManagementController(IUserService userService, IFeatureService featureService, @@ -57,7 +59,8 @@ public class AccountsKeyManagementController : Controller emergencyAccessValidator, IRotationValidator, IReadOnlyList> organizationUserValidator, - IRotationValidator, IEnumerable> webAuthnKeyValidator) + IRotationValidator, IEnumerable> webAuthnKeyValidator, + IRotationValidator, IEnumerable> deviceValidator) { _userService = userService; _featureService = featureService; @@ -71,6 +74,7 @@ public class AccountsKeyManagementController : Controller _emergencyAccessValidator = emergencyAccessValidator; _organizationUserValidator = organizationUserValidator; _webauthnKeyValidator = webAuthnKeyValidator; + _deviceValidator = deviceValidator; } [HttpPost("regenerate-keys")] @@ -109,6 +113,7 @@ public class AccountsKeyManagementController : Controller EmergencyAccesses = await _emergencyAccessValidator.ValidateAsync(user, model.AccountUnlockData.EmergencyAccessUnlockData), OrganizationUsers = await _organizationUserValidator.ValidateAsync(user, model.AccountUnlockData.OrganizationAccountRecoveryUnlockData), WebAuthnKeys = await _webauthnKeyValidator.ValidateAsync(user, model.AccountUnlockData.PasskeyUnlockData), + DeviceKeys = await _deviceValidator.ValidateAsync(user, model.AccountUnlockData.DeviceKeyUnlockData), Ciphers = await _cipherValidator.ValidateAsync(user, model.AccountData.Ciphers), Folders = await _folderValidator.ValidateAsync(user, model.AccountData.Folders), diff --git a/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs b/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs index 5156e2a655..23c3eb95d0 100644 --- a/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs @@ -3,6 +3,7 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Auth.Models.Request.WebAuthn; +using Bit.Core.Auth.Models.Api.Request; namespace Bit.Api.KeyManagement.Models.Requests; @@ -13,4 +14,5 @@ public class UnlockDataRequestModel public required IEnumerable EmergencyAccessUnlockData { get; set; } public required IEnumerable OrganizationAccountRecoveryUnlockData { get; set; } public required IEnumerable PasskeyUnlockData { get; set; } + public required IEnumerable DeviceKeyUnlockData { get; set; } } diff --git a/src/Api/KeyManagement/Validators/DeviceRotationValidator.cs b/src/Api/KeyManagement/Validators/DeviceRotationValidator.cs new file mode 100644 index 0000000000..cbaf508766 --- /dev/null +++ b/src/Api/KeyManagement/Validators/DeviceRotationValidator.cs @@ -0,0 +1,53 @@ +using Bit.Core.Auth.Models.Api.Request; +using Bit.Core.Auth.Utilities; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; + +namespace Bit.Api.KeyManagement.Validators; + +/// +/// Device implementation for +/// +public class DeviceRotationValidator : IRotationValidator, IEnumerable> +{ + private readonly IDeviceRepository _deviceRepository; + + /// + /// Instantiates a new + /// + /// Retrieves all user s + public DeviceRotationValidator(IDeviceRepository deviceRepository) + { + _deviceRepository = deviceRepository; + } + + public async Task> ValidateAsync(User user, IEnumerable devices) + { + var result = new List(); + + var existingTrustedDevices = (await _deviceRepository.GetManyByUserIdAsync(user.Id)).Where(d => d.IsTrusted()).ToList(); + if (existingTrustedDevices.Count == 0) + { + return result; + } + + foreach (var existing in existingTrustedDevices) + { + var device = devices.FirstOrDefault(c => c.DeviceId == existing.Id); + if (device == null) + { + throw new BadRequestException("All existing trusted devices must be included in the rotation."); + } + + if (device.EncryptedUserKey == null || device.EncryptedPublicKey == null) + { + throw new BadRequestException("Rotated encryption keys must be provided for all devices that are trusted."); + } + + result.Add(device.ToDevice(existing)); + } + + return result; + } +} diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index 5849bfb634..deac7bf0c9 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -31,7 +31,7 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Identity.TokenProviders; using Bit.Core.Tools.ImportFeatures; using Bit.Core.Tools.ReportFeatures; - +using Bit.Core.Auth.Models.Api.Request; #if !OSS using Bit.Commercial.Core.SecretsManager; @@ -168,6 +168,9 @@ public class Startup services .AddScoped, IEnumerable>, WebAuthnLoginKeyRotationValidator>(); + services + .AddScoped, IEnumerable>, + DeviceRotationValidator>(); // Services services.AddBaseServices(globalSettings); diff --git a/src/Core/Auth/Models/Api/Request/DeviceKeysUpdateRequestModel.cs b/src/Core/Auth/Models/Api/Request/DeviceKeysUpdateRequestModel.cs index 2b815afd16..111b03a3a3 100644 --- a/src/Core/Auth/Models/Api/Request/DeviceKeysUpdateRequestModel.cs +++ b/src/Core/Auth/Models/Api/Request/DeviceKeysUpdateRequestModel.cs @@ -1,4 +1,5 @@ using System.ComponentModel.DataAnnotations; +using Bit.Core.Entities; using Bit.Core.Utilities; namespace Bit.Core.Auth.Models.Api.Request; @@ -7,6 +8,13 @@ public class OtherDeviceKeysUpdateRequestModel : DeviceKeysUpdateRequestModel { [Required] public Guid DeviceId { get; set; } + + public Device ToDevice(Device existingDevice) + { + existingDevice.EncryptedPublicKey = EncryptedPublicKey; + existingDevice.EncryptedUserKey = EncryptedUserKey; + return existingDevice; + } } public class DeviceKeysUpdateRequestModel diff --git a/src/Core/Auth/Models/Api/Response/DeviceAuthRequestResponseModel.cs b/src/Core/Auth/Models/Api/Response/DeviceAuthRequestResponseModel.cs index 3cfea51ee3..59630a6d2c 100644 --- a/src/Core/Auth/Models/Api/Response/DeviceAuthRequestResponseModel.cs +++ b/src/Core/Auth/Models/Api/Response/DeviceAuthRequestResponseModel.cs @@ -1,5 +1,4 @@ using Bit.Core.Auth.Models.Data; -using Bit.Core.Auth.Utilities; using Bit.Core.Enums; using Bit.Core.Models.Api; @@ -19,7 +18,7 @@ public class DeviceAuthRequestResponseModel : ResponseModel Type = deviceAuthDetails.Type, Identifier = deviceAuthDetails.Identifier, CreationDate = deviceAuthDetails.CreationDate, - IsTrusted = deviceAuthDetails.IsTrusted() + IsTrusted = deviceAuthDetails.IsTrusted, }; if (deviceAuthDetails.AuthRequestId != null && deviceAuthDetails.AuthRequestCreatedAt != null) diff --git a/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs b/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs index 7cb1c273a3..f81baf6fab 100644 --- a/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs +++ b/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs @@ -20,6 +20,7 @@ public class RotateUserAccountKeysData public IEnumerable EmergencyAccesses { get; set; } public IReadOnlyList OrganizationUsers { get; set; } public IEnumerable WebAuthnKeys { get; set; } + public IEnumerable DeviceKeys { get; set; } // User vault data encrypted by the userkey public IEnumerable Ciphers { get; set; } diff --git a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs index f4dcf31d5c..6967c9bf85 100644 --- a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs +++ b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs @@ -20,6 +20,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand private readonly ISendRepository _sendRepository; private readonly IEmergencyAccessRepository _emergencyAccessRepository; private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IDeviceRepository _deviceRepository; private readonly IPushNotificationService _pushService; private readonly IdentityErrorDescriber _identityErrorDescriber; private readonly IWebAuthnCredentialRepository _credentialRepository; @@ -42,6 +43,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand public RotateUserAccountKeysCommand(IUserService userService, IUserRepository userRepository, ICipherRepository cipherRepository, IFolderRepository folderRepository, ISendRepository sendRepository, IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository, + IDeviceRepository deviceRepository, IPasswordHasher passwordHasher, IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository) { @@ -52,6 +54,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand _sendRepository = sendRepository; _emergencyAccessRepository = emergencyAccessRepository; _organizationUserRepository = organizationUserRepository; + _deviceRepository = deviceRepository; _pushService = pushService; _identityErrorDescriber = errors; _credentialRepository = credentialRepository; @@ -127,6 +130,11 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand saveEncryptedDataActions.Add(_credentialRepository.UpdateKeysForRotationAsync(user.Id, model.WebAuthnKeys)); } + if (model.DeviceKeys.Any()) + { + saveEncryptedDataActions.Add(_deviceRepository.UpdateKeysForRotationAsync(user.Id, model.DeviceKeys)); + } + await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions); await _pushService.PushLogOutAsync(user.Id); return IdentityResult.Success; diff --git a/src/Core/Repositories/IDeviceRepository.cs b/src/Core/Repositories/IDeviceRepository.cs index c9809c1de6..fc2f1556b7 100644 --- a/src/Core/Repositories/IDeviceRepository.cs +++ b/src/Core/Repositories/IDeviceRepository.cs @@ -1,5 +1,6 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; +using Bit.Core.KeyManagement.UserKey; #nullable enable @@ -16,4 +17,5 @@ public interface IDeviceRepository : IRepository // other requests. Task> GetManyByUserIdWithDeviceAuth(Guid userId); Task ClearPushTokenAsync(Guid id); + UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable devices); } diff --git a/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs b/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs index 4abf4a4649..723200ff1c 100644 --- a/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs @@ -1,8 +1,10 @@ using System.Data; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; +using Bit.Core.KeyManagement.UserKey; using Bit.Core.Repositories; using Bit.Core.Settings; +using Bit.Core.Utilities; using Dapper; using Microsoft.Data.SqlClient; @@ -109,4 +111,35 @@ public class DeviceRepository : Repository, IDeviceRepository commandType: CommandType.StoredProcedure); } } + + public UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable devices) + { + return async (SqlConnection connection, SqlTransaction transaction) => + { + const string sql = @" + UPDATE D + SET + D.[EncryptedPublicKey] = UD.[encryptedPublicKey], + D.[EncryptedUserKey] = UD.[encryptedUserKey] + FROM + [dbo].[Device] D + INNER JOIN + OPENJSON(@DeviceCredentials) + WITH ( + id UNIQUEIDENTIFIER, + encryptedPublicKey NVARCHAR(MAX), + encryptedUserKey NVARCHAR(MAX) + ) UD + ON UD.[id] = D.[Id] + WHERE + D.[UserId] = @UserId"; + var deviceCredentials = CoreHelpers.ClassToJsonData(devices); + + await connection.ExecuteAsync( + sql, + new { UserId = userId, DeviceCredentials = deviceCredentials }, + transaction: transaction, + commandType: CommandType.Text); + }; + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs b/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs index ad31d0fb8b..19f38c6098 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs @@ -1,5 +1,6 @@ using AutoMapper; using Bit.Core.Auth.Models.Data; +using Bit.Core.KeyManagement.UserKey; using Bit.Core.Repositories; using Bit.Core.Settings; using Bit.Infrastructure.EntityFramework.Auth.Repositories.Queries; @@ -91,4 +92,30 @@ public class DeviceRepository : Repository, return await query.GetQuery(dbContext, userId, expirationMinutes).ToListAsync(); } } + + public UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable devices) + { + return async (_, _) => + { + var deviceUpdates = devices.ToList(); + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + var userDevices = await GetDbSet(dbContext) + .Where(device => device.UserId == userId) + .ToListAsync(); + var userDevicesWithUpdatesPending = userDevices + .Where(existingDevice => deviceUpdates.Any(updatedDevice => updatedDevice.Id == existingDevice.Id)) + .ToList(); + + foreach (var deviceToUpdate in userDevicesWithUpdatesPending) + { + var deviceUpdate = deviceUpdates.First(deviceUpdate => deviceUpdate.Id == deviceToUpdate.Id); + deviceToUpdate.EncryptedPublicKey = deviceUpdate.EncryptedPublicKey; + deviceToUpdate.EncryptedUserKey = deviceUpdate.EncryptedUserKey; + } + + await dbContext.SaveChangesAsync(); + }; + } + } diff --git a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs index 7c05e1d680..1b065adbd6 100644 --- a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs +++ b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs @@ -29,6 +29,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture _passwordHasher; private string _ownerEmail = null!; @@ -40,6 +41,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture(); + _deviceRepository = _factory.GetService(); _emergencyAccessRepository = _factory.GetService(); _organizationUserRepository = _factory.GetService(); _passwordHasher = _factory.GetService>(); @@ -238,10 +240,12 @@ public class AccountsKeyManagementControllerTests : IClassFixture sutProvider, User user, IEnumerable devices) + { + var userCiphers = devices.Select(c => new Device { Id = c.DeviceId, EncryptedPrivateKey = "EncryptedPrivateKey", EncryptedPublicKey = "EncryptedPublicKey", EncryptedUserKey = "EncryptedUserKey" }).ToList(); + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id) + .Returns(userCiphers); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.ValidateAsync(user, Enumerable.Empty())); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_SentDevicesTrustedButDatabaseUntrusted_Throws( + SutProvider sutProvider, User user, IEnumerable devices) + { + var userCiphers = devices.Select(c => new Device { Id = c.DeviceId, EncryptedPrivateKey = "Key", EncryptedPublicKey = "Key", EncryptedUserKey = "Key" }).ToList(); + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id) + .Returns(userCiphers); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.ValidateAsync(user, [ + new OtherDeviceKeysUpdateRequestModel { DeviceId = userCiphers.First().Id, EncryptedPublicKey = null, EncryptedUserKey = null } + ])); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_Validates( + SutProvider sutProvider, User user, IEnumerable devices) + { + var userCiphers = devices.Select(c => new Device { Id = c.DeviceId, EncryptedPrivateKey = "Key", EncryptedPublicKey = "Key", EncryptedUserKey = "Key" }).ToList().Slice(0, 1); + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id) + .Returns(userCiphers); + Assert.NotEmpty(await sutProvider.Sut.ValidateAsync(user, [ + new OtherDeviceKeysUpdateRequestModel { DeviceId = userCiphers.First().Id, EncryptedPublicKey = "Key", EncryptedUserKey = "Key" } + ])); + } +} From 1cc854ddb90e7e638ec9fcd7651562ec86455aa8 Mon Sep 17 00:00:00 2001 From: Github Actions Date: Thu, 3 Apr 2025 12:35:46 +0000 Subject: [PATCH 09/15] Bumped version to 2025.4.0 --- Directory.Build.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Directory.Build.props b/Directory.Build.props index 2ede6ad8d1..858abb2bc8 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2025.3.3 + 2025.4.0 Bit.$(MSBuildProjectName) enable From 282e80ca026e1d1bd7b28be9bcfa4c1f4ed618ba Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Thu, 3 Apr 2025 08:51:09 -0400 Subject: [PATCH 10/15] [PM-13837] Switch provider price IDs (#5518) * Add ProviderPriceAdapter This is a temporary utility that will be used to manage retrieval of provider price IDs until all providers can be migrated to the new price structure. * Updated ProviderBillingService.ChangePlan * Update ProviderBillingService.SetupSubscription * Update ProviderBillingService.UpdateSeatMinimums * Update ProviderBillingService.CurrySeatScalingUpdate * Mark StripeProviderPortalSeatPlanId obsolete * Run dotnet format --- .../Billing/ProviderBillingService.cs | 130 +++++++------- .../Billing/ProviderPriceAdapter.cs | 133 ++++++++++++++ .../Billing/ProviderBillingServiceTests.cs | 169 +++++++++++++----- .../Billing/ProviderPriceAdapterTests.cs | 151 ++++++++++++++++ .../Controllers/ProvidersController.cs | 10 +- .../Implementations/ProviderMigrator.cs | 3 +- src/Core/Billing/Models/StaticStore/Plan.cs | 1 + .../Contracts/ChangeProviderPlansCommand.cs | 7 +- .../UpdateProviderSeatMinimumsCommand.cs | 8 +- .../Business/ProviderSubscriptionUpdate.cs | 62 ------- src/Core/Services/IPaymentService.cs | 6 - .../Implementations/StripePaymentService.cs | 13 -- 12 files changed, 480 insertions(+), 213 deletions(-) create mode 100644 bitwarden_license/src/Commercial.Core/Billing/ProviderPriceAdapter.cs create mode 100644 bitwarden_license/test/Commercial.Core.Test/Billing/ProviderPriceAdapterTests.cs delete mode 100644 src/Core/Models/Business/ProviderSubscriptionUpdate.cs diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index 65e41ab586..757d6510f1 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -35,7 +35,6 @@ public class ProviderBillingService( IGlobalSettings globalSettings, ILogger logger, IOrganizationRepository organizationRepository, - IPaymentService paymentService, IPricingClient pricingClient, IProviderInvoiceItemRepository providerInvoiceItemRepository, IProviderOrganizationRepository providerOrganizationRepository, @@ -148,36 +147,29 @@ public class ProviderBillingService( public async Task ChangePlan(ChangeProviderPlanCommand command) { - var plan = await providerPlanRepository.GetByIdAsync(command.ProviderPlanId); + var (provider, providerPlanId, newPlanType) = command; - if (plan == null) + var providerPlan = await providerPlanRepository.GetByIdAsync(providerPlanId); + + if (providerPlan == null) { throw new BadRequestException("Provider plan not found."); } - if (plan.PlanType == command.NewPlan) + if (providerPlan.PlanType == newPlanType) { return; } - var oldPlanConfiguration = await pricingClient.GetPlanOrThrow(plan.PlanType); - var newPlanConfiguration = await pricingClient.GetPlanOrThrow(command.NewPlan); + var subscription = await subscriberService.GetSubscriptionOrThrow(provider); - plan.PlanType = command.NewPlan; - await providerPlanRepository.ReplaceAsync(plan); + var oldPriceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); + var newPriceId = ProviderPriceAdapter.GetPriceId(provider, subscription, newPlanType); - Subscription subscription; - try - { - subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, plan.ProviderId); - } - catch (InvalidOperationException) - { - throw new ConflictException("Subscription not found."); - } + providerPlan.PlanType = newPlanType; + await providerPlanRepository.ReplaceAsync(providerPlan); - var oldSubscriptionItem = subscription.Items.SingleOrDefault(x => - x.Price.Id == oldPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId); + var oldSubscriptionItem = subscription.Items.SingleOrDefault(x => x.Price.Id == oldPriceId); var updateOptions = new SubscriptionUpdateOptions { @@ -185,7 +177,7 @@ public class ProviderBillingService( [ new SubscriptionItemOptions { - Price = newPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId, + Price = newPriceId, Quantity = oldSubscriptionItem!.Quantity }, new SubscriptionItemOptions @@ -196,12 +188,14 @@ public class ProviderBillingService( ] }; - await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId, updateOptions); + await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, updateOptions); // Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId) // 1. Retrieve PlanType and PlanName for ProviderPlan // 2. Assign PlanType & PlanName to Organization - var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(plan.ProviderId); + var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId); + + var newPlan = await pricingClient.GetPlanOrThrow(newPlanType); foreach (var providerOrganization in providerOrganizations) { @@ -210,8 +204,8 @@ public class ProviderBillingService( { throw new ConflictException($"Organization '{providerOrganization.Id}' not found."); } - organization.PlanType = command.NewPlan; - organization.Plan = newPlanConfiguration.Name; + organization.PlanType = newPlanType; + organization.Plan = newPlan.Name; await organizationRepository.ReplaceAsync(organization); } } @@ -405,7 +399,7 @@ public class ProviderBillingService( var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment; - var update = CurrySeatScalingUpdate( + var scaleQuantityTo = CurrySeatScalingUpdate( provider, providerPlan, newlyAssignedSeatTotal); @@ -428,9 +422,7 @@ public class ProviderBillingService( else if (currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum) { - await update( - seatMinimum, - newlyAssignedSeatTotal); + await scaleQuantityTo(newlyAssignedSeatTotal); } /* * Above the limit => Above the limit: @@ -439,9 +431,7 @@ public class ProviderBillingService( else if (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum) { - await update( - currentlyAssignedSeatTotal, - newlyAssignedSeatTotal); + await scaleQuantityTo(newlyAssignedSeatTotal); } /* * Above the limit => Below the limit: @@ -450,9 +440,7 @@ public class ProviderBillingService( else if (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal <= seatMinimum) { - await update( - currentlyAssignedSeatTotal, - seatMinimum); + await scaleQuantityTo(seatMinimum); } } @@ -586,9 +574,11 @@ public class ProviderBillingService( throw new BillingException(); } + var priceId = ProviderPriceAdapter.GetActivePriceId(provider, providerPlan.PlanType); + subscriptionItemOptionsList.Add(new SubscriptionItemOptions { - Price = plan.PasswordManager.StripeProviderPortalSeatPlanId, + Price = priceId, Quantity = providerPlan.SeatMinimum }); } @@ -654,43 +644,37 @@ public class ProviderBillingService( public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command) { - if (command.Configuration.Any(x => x.SeatsMinimum < 0)) + var (provider, updatedPlanConfigurations) = command; + + if (updatedPlanConfigurations.Any(x => x.SeatsMinimum < 0)) { throw new BadRequestException("Provider seat minimums must be at least 0."); } - Subscription subscription; - try - { - subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, command.Id); - } - catch (InvalidOperationException) - { - throw new ConflictException("Subscription not found."); - } + var subscription = await subscriberService.GetSubscriptionOrThrow(provider); var subscriptionItemOptionsList = new List(); - var providerPlans = await providerPlanRepository.GetByProviderId(command.Id); + var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - foreach (var newPlanConfiguration in command.Configuration) + foreach (var updatedPlanConfiguration in updatedPlanConfigurations) { + var (updatedPlanType, updatedSeatMinimum) = updatedPlanConfiguration; + var providerPlan = - providerPlans.Single(providerPlan => providerPlan.PlanType == newPlanConfiguration.Plan); + providerPlans.Single(providerPlan => providerPlan.PlanType == updatedPlanType); - if (providerPlan.SeatMinimum != newPlanConfiguration.SeatsMinimum) + if (providerPlan.SeatMinimum != updatedSeatMinimum) { - var newPlan = await pricingClient.GetPlanOrThrow(newPlanConfiguration.Plan); - - var priceId = newPlan.PasswordManager.StripeProviderPortalSeatPlanId; + var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, updatedPlanType); var subscriptionItem = subscription.Items.First(item => item.Price.Id == priceId); if (providerPlan.PurchasedSeats == 0) { - if (providerPlan.AllocatedSeats > newPlanConfiguration.SeatsMinimum) + if (providerPlan.AllocatedSeats > updatedSeatMinimum) { - providerPlan.PurchasedSeats = providerPlan.AllocatedSeats - newPlanConfiguration.SeatsMinimum; + providerPlan.PurchasedSeats = providerPlan.AllocatedSeats - updatedSeatMinimum; subscriptionItemOptionsList.Add(new SubscriptionItemOptions { @@ -705,7 +689,7 @@ public class ProviderBillingService( { Id = subscriptionItem.Id, Price = priceId, - Quantity = newPlanConfiguration.SeatsMinimum + Quantity = updatedSeatMinimum }); } } @@ -713,9 +697,9 @@ public class ProviderBillingService( { var totalSeats = providerPlan.SeatMinimum + providerPlan.PurchasedSeats; - if (newPlanConfiguration.SeatsMinimum <= totalSeats) + if (updatedSeatMinimum <= totalSeats) { - providerPlan.PurchasedSeats = totalSeats - newPlanConfiguration.SeatsMinimum; + providerPlan.PurchasedSeats = totalSeats - updatedSeatMinimum; } else { @@ -724,12 +708,12 @@ public class ProviderBillingService( { Id = subscriptionItem.Id, Price = priceId, - Quantity = newPlanConfiguration.SeatsMinimum + Quantity = updatedSeatMinimum }); } } - providerPlan.SeatMinimum = newPlanConfiguration.SeatsMinimum; + providerPlan.SeatMinimum = updatedSeatMinimum; await providerPlanRepository.ReplaceAsync(providerPlan); } @@ -737,23 +721,33 @@ public class ProviderBillingService( if (subscriptionItemOptionsList.Count > 0) { - await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId, + await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { Items = subscriptionItemOptionsList }); } } - private Func CurrySeatScalingUpdate( + private Func CurrySeatScalingUpdate( Provider provider, ProviderPlan providerPlan, - int newlyAssignedSeats) => async (currentlySubscribedSeats, newlySubscribedSeats) => + int newlyAssignedSeats) => async newlySubscribedSeats => { - var plan = await pricingClient.GetPlanOrThrow(providerPlan.PlanType); + var subscription = await subscriberService.GetSubscriptionOrThrow(provider); - await paymentService.AdjustSeats( - provider, - plan, - currentlySubscribedSeats, - newlySubscribedSeats); + var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); + + var item = subscription.Items.First(item => item.Price.Id == priceId); + + await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions + { + Items = [ + new SubscriptionItemOptions + { + Id = item.Id, + Price = priceId, + Quantity = newlySubscribedSeats + } + ] + }); var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum ? newlySubscribedSeats - providerPlan.SeatMinimum diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderPriceAdapter.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderPriceAdapter.cs new file mode 100644 index 0000000000..4cc0711ec9 --- /dev/null +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderPriceAdapter.cs @@ -0,0 +1,133 @@ +// ReSharper disable SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault +#nullable enable +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing; +using Bit.Core.Billing.Enums; +using Stripe; + +namespace Bit.Commercial.Core.Billing; + +public static class ProviderPriceAdapter +{ + public static class MSP + { + public static class Active + { + public const string Enterprise = "provider-portal-enterprise-monthly-2025"; + public const string Teams = "provider-portal-teams-monthly-2025"; + } + + public static class Legacy + { + public const string Enterprise = "password-manager-provider-portal-enterprise-monthly-2024"; + public const string Teams = "password-manager-provider-portal-teams-monthly-2024"; + public static readonly List List = [Enterprise, Teams]; + } + } + + public static class BusinessUnit + { + public static class Active + { + public const string Annually = "business-unit-portal-enterprise-annually-2025"; + public const string Monthly = "business-unit-portal-enterprise-monthly-2025"; + } + + public static class Legacy + { + public const string Annually = "password-manager-provider-portal-enterprise-annually-2024"; + public const string Monthly = "password-manager-provider-portal-enterprise-monthly-2024"; + public static readonly List List = [Annually, Monthly]; + } + } + + /// + /// Uses the 's and to determine + /// whether the is on active or legacy pricing and then returns a Stripe price ID for the provided + /// based on that determination. + /// + /// The provider to get the Stripe price ID for. + /// The provider's subscription. + /// The plan type correlating to the desired Stripe price ID. + /// A Stripe ID. + /// Thrown when the provider's type is not or . + /// Thrown when the provided does not relate to a Stripe price ID. + public static string GetPriceId( + Provider provider, + Subscription subscription, + PlanType planType) + { + var priceIds = subscription.Items.Select(item => item.Price.Id); + + var invalidPlanType = + new BillingException(message: $"PlanType {planType} does not have an associated provider price in Stripe"); + + return provider.Type switch + { + ProviderType.Msp => MSP.Legacy.List.Intersect(priceIds).Any() + ? planType switch + { + PlanType.TeamsMonthly => MSP.Legacy.Teams, + PlanType.EnterpriseMonthly => MSP.Legacy.Enterprise, + _ => throw invalidPlanType + } + : planType switch + { + PlanType.TeamsMonthly => MSP.Active.Teams, + PlanType.EnterpriseMonthly => MSP.Active.Enterprise, + _ => throw invalidPlanType + }, + ProviderType.MultiOrganizationEnterprise => BusinessUnit.Legacy.List.Intersect(priceIds).Any() + ? planType switch + { + PlanType.EnterpriseAnnually => BusinessUnit.Legacy.Annually, + PlanType.EnterpriseMonthly => BusinessUnit.Legacy.Monthly, + _ => throw invalidPlanType + } + : planType switch + { + PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually, + PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly, + _ => throw invalidPlanType + }, + _ => throw new BillingException( + $"ProviderType {provider.Type} does not have any associated provider price IDs") + }; + } + + /// + /// Uses the 's to return the active Stripe price ID for the provided + /// . + /// + /// The provider to get the Stripe price ID for. + /// The plan type correlating to the desired Stripe price ID. + /// A Stripe ID. + /// Thrown when the provider's type is not or . + /// Thrown when the provided does not relate to a Stripe price ID. + public static string GetActivePriceId( + Provider provider, + PlanType planType) + { + var invalidPlanType = + new BillingException(message: $"PlanType {planType} does not have an associated provider price in Stripe"); + + return provider.Type switch + { + ProviderType.Msp => planType switch + { + PlanType.TeamsMonthly => MSP.Active.Teams, + PlanType.EnterpriseMonthly => MSP.Active.Enterprise, + _ => throw invalidPlanType + }, + ProviderType.MultiOrganizationEnterprise => planType switch + { + PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually, + PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly, + _ => throw invalidPlanType + }, + _ => throw new BillingException( + $"ProviderType {provider.Type} does not have any associated provider price IDs") + }; + } +} diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index 71a150a546..ab1000d631 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -4,6 +4,7 @@ using Bit.Commercial.Core.Billing; using Bit.Commercial.Core.Billing.Models; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; @@ -115,6 +116,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.MultiOrganizationEnterprise; + var providerPlanRepository = sutProvider.GetDependency(); var existingPlan = new ProviderPlan { @@ -132,10 +135,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetPlanOrThrow(existingPlan.PlanType) .Returns(StaticStore.GetPlan(existingPlan.PlanType)); - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.ProviderSubscriptionGetAsync( - Arg.Is(provider.GatewaySubscriptionId), - Arg.Is(provider.Id)) + sutProvider.GetDependency().GetSubscriptionOrThrow(provider) .Returns(new Subscription { Id = provider.GatewaySubscriptionId, @@ -158,7 +158,7 @@ public class ProviderBillingServiceTests }); var command = - new ChangeProviderPlanCommand(providerPlanId, PlanType.EnterpriseMonthly, provider.GatewaySubscriptionId); + new ChangeProviderPlanCommand(provider, providerPlanId, PlanType.EnterpriseMonthly); sutProvider.GetDependency().GetPlanOrThrow(command.NewPlan) .Returns(StaticStore.GetPlan(command.NewPlan)); @@ -170,6 +170,8 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1) .ReplaceAsync(Arg.Is(p => p.PlanType == PlanType.EnterpriseMonthly)); + var stripeAdapter = sutProvider.GetDependency(); + await stripeAdapter.Received(1) .SubscriptionUpdateAsync( Arg.Is(provider.GatewaySubscriptionId), @@ -405,6 +407,23 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } }, + new SubscriptionItem + { + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } + } + ] + } + }; + + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); + // 50 seats currently assigned with a seat minimum of 100 var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); @@ -427,11 +446,9 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().AdjustSeats( - Arg.Any(), - Arg.Any(), - Arg.Any(), - Arg.Any()); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().SubscriptionUpdateAsync( + Arg.Any(), + Arg.Any()); await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( pPlan => pPlan.AllocatedSeats == 60)); @@ -474,6 +491,23 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } }, + new SubscriptionItem + { + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } + } + ] + } + }; + + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); + // 95 seats currently assigned with a seat minimum of 100 var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); @@ -496,11 +530,12 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 95 current + 10 seat scale = 105 seats, 5 above the minimum - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - providerPlan.SeatMinimum!.Value, - 105); + await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + provider.GatewaySubscriptionId, + Arg.Is( + options => + options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams && + options.Items.First().Quantity == 105)); // 105 total seats - 100 minimum = 5 purchased seats await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( @@ -544,6 +579,23 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } }, + new SubscriptionItem + { + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } + } + ] + } + }; + + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); + // 110 seats currently assigned with a seat minimum of 100 var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); @@ -566,11 +618,12 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 110 current + 10 seat scale up = 120 seats - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - 110, - 120); + await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + provider.GatewaySubscriptionId, + Arg.Is( + options => + options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams && + options.Items.First().Quantity == 120)); // 120 total seats - 100 seat minimum = 20 purchased seats await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( @@ -614,6 +667,23 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } }, + new SubscriptionItem + { + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } + } + ] + } + }; + + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); + // 110 seats currently assigned with a seat minimum of 100 var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); @@ -636,11 +706,12 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, -30); // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - 110, - providerPlan.SeatMinimum!.Value); + await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + provider.GatewaySubscriptionId, + Arg.Is( + options => + options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams && + options.Items.First().Quantity == providerPlan.SeatMinimum!.Value)); // Being below the seat minimum means no purchased seats. await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( @@ -977,6 +1048,7 @@ public class ProviderBillingServiceTests SutProvider sutProvider, Provider provider) { + provider.Type = ProviderType.Msp; provider.GatewaySubscriptionId = null; var customer = new Customer @@ -1020,9 +1092,6 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id) .Returns(providerPlans); - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); - var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; sutProvider.GetDependency() @@ -1045,9 +1114,9 @@ public class ProviderBillingServiceTests sub.Customer == "customer_id" && sub.DaysUntilDue == 30 && sub.Items.Count == 2 && - sub.Items.ElementAt(0).Price == teamsPlan.PasswordManager.StripeProviderPortalSeatPlanId && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && sub.Items.ElementAt(0).Quantity == 100 && - sub.Items.ElementAt(1).Price == enterprisePlan.PasswordManager.StripeProviderPortalSeatPlanId && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && sub.Items.ElementAt(1).Quantity == 100 && sub.Metadata["providerId"] == provider.Id.ToString() && sub.OffSession == true && @@ -1069,8 +1138,7 @@ public class ProviderBillingServiceTests { // Arrange var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.TeamsMonthly, -10), (PlanType.EnterpriseMonthly, 50) @@ -1089,6 +1157,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1118,9 +1188,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync( - provider.GatewaySubscriptionId, - provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1137,8 +1205,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 30), (PlanType.TeamsMonthly, 20) @@ -1170,6 +1237,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1199,7 +1268,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1216,8 +1285,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 70), (PlanType.TeamsMonthly, 50) @@ -1249,6 +1317,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1278,7 +1348,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1295,8 +1365,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 60), (PlanType.TeamsMonthly, 60) @@ -1322,6 +1391,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1351,7 +1422,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1368,8 +1439,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 80), (PlanType.TeamsMonthly, 80) @@ -1401,6 +1471,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1430,7 +1502,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1447,8 +1519,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 70), (PlanType.TeamsMonthly, 30) diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderPriceAdapterTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderPriceAdapterTests.cs new file mode 100644 index 0000000000..4fce78c05a --- /dev/null +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderPriceAdapterTests.cs @@ -0,0 +1,151 @@ +using Bit.Commercial.Core.Billing; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Enums; +using Stripe; +using Xunit; + +namespace Bit.Commercial.Core.Test.Billing; + +public class ProviderPriceAdapterTests +{ + [Theory] + [InlineData("password-manager-provider-portal-enterprise-monthly-2024", PlanType.EnterpriseMonthly)] + [InlineData("password-manager-provider-portal-teams-monthly-2024", PlanType.TeamsMonthly)] + public void GetPriceId_MSP_Legacy_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.Msp + }; + + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = priceId } } + ] + } + }; + + var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("provider-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)] + [InlineData("provider-portal-teams-monthly-2025", PlanType.TeamsMonthly)] + public void GetPriceId_MSP_Active_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.Msp + }; + + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = priceId } } + ] + } + }; + + var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("password-manager-provider-portal-enterprise-annually-2024", PlanType.EnterpriseAnnually)] + [InlineData("password-manager-provider-portal-enterprise-monthly-2024", PlanType.EnterpriseMonthly)] + public void GetPriceId_BusinessUnit_Legacy_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.MultiOrganizationEnterprise + }; + + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = priceId } } + ] + } + }; + + var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("business-unit-portal-enterprise-annually-2025", PlanType.EnterpriseAnnually)] + [InlineData("business-unit-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)] + public void GetPriceId_BusinessUnit_Active_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.MultiOrganizationEnterprise + }; + + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = priceId } } + ] + } + }; + + var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("provider-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)] + [InlineData("provider-portal-teams-monthly-2025", PlanType.TeamsMonthly)] + public void GetActivePriceId_MSP_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.Msp + }; + + var result = ProviderPriceAdapter.GetActivePriceId(provider, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("business-unit-portal-enterprise-annually-2025", PlanType.EnterpriseAnnually)] + [InlineData("business-unit-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)] + public void GetActivePriceId_BusinessUnit_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.MultiOrganizationEnterprise + }; + + var result = ProviderPriceAdapter.GetActivePriceId(provider, planType); + + Assert.Equal(result, priceId); + } +} diff --git a/src/Admin/AdminConsole/Controllers/ProvidersController.cs b/src/Admin/AdminConsole/Controllers/ProvidersController.cs index c38bb64419..0b1e4035df 100644 --- a/src/Admin/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Admin/AdminConsole/Controllers/ProvidersController.cs @@ -300,8 +300,7 @@ public class ProvidersController : Controller { case ProviderType.Msp: var updateMspSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (Plan: PlanType.TeamsMonthly, SeatsMinimum: model.TeamsMonthlySeatMinimum), (Plan: PlanType.EnterpriseMonthly, SeatsMinimum: model.EnterpriseMonthlySeatMinimum) @@ -314,15 +313,14 @@ public class ProvidersController : Controller // 1. Change the plan and take over any old values. var changeMoePlanCommand = new ChangeProviderPlanCommand( + provider, existingMoePlan.Id, - model.Plan!.Value, - provider.GatewaySubscriptionId); + model.Plan!.Value); await _providerBillingService.ChangePlan(changeMoePlanCommand); // 2. Update the seat minimums. var updateMoeSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (Plan: model.Plan!.Value, SeatsMinimum: model.EnterpriseMinimumSeats!.Value) ]); diff --git a/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs b/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs index b5c4383556..384cfca1d1 100644 --- a/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs +++ b/src/Core/Billing/Migration/Services/Implementations/ProviderMigrator.cs @@ -309,8 +309,7 @@ public class ProviderMigrator( .SeatMinimum ?? 0; var updateSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (Plan: PlanType.EnterpriseMonthly, SeatsMinimum: enterpriseSeatMinimum), (Plan: PlanType.TeamsMonthly, SeatsMinimum: teamsSeatMinimum) diff --git a/src/Core/Billing/Models/StaticStore/Plan.cs b/src/Core/Billing/Models/StaticStore/Plan.cs index 5dbcd7ddc4..17aa78aa06 100644 --- a/src/Core/Billing/Models/StaticStore/Plan.cs +++ b/src/Core/Billing/Models/StaticStore/Plan.cs @@ -75,6 +75,7 @@ public abstract record Plan // Seats public string StripePlanId { get; init; } public string StripeSeatPlanId { get; init; } + [Obsolete("No longer used to retrieve a provider's price ID. Use ProviderPriceAdapter instead.")] public string StripeProviderPortalSeatPlanId { get; init; } public decimal BasePrice { get; init; } public decimal SeatPrice { get; init; } diff --git a/src/Core/Billing/Services/Contracts/ChangeProviderPlansCommand.cs b/src/Core/Billing/Services/Contracts/ChangeProviderPlansCommand.cs index 3e8fffdd11..385782c8ad 100644 --- a/src/Core/Billing/Services/Contracts/ChangeProviderPlansCommand.cs +++ b/src/Core/Billing/Services/Contracts/ChangeProviderPlansCommand.cs @@ -1,8 +1,9 @@ -using Bit.Core.Billing.Enums; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Enums; namespace Bit.Core.Billing.Services.Contracts; public record ChangeProviderPlanCommand( + Provider Provider, Guid ProviderPlanId, - PlanType NewPlan, - string GatewaySubscriptionId); + PlanType NewPlan); diff --git a/src/Core/Billing/Services/Contracts/UpdateProviderSeatMinimumsCommand.cs b/src/Core/Billing/Services/Contracts/UpdateProviderSeatMinimumsCommand.cs index 86a596ffb6..2d2535b60a 100644 --- a/src/Core/Billing/Services/Contracts/UpdateProviderSeatMinimumsCommand.cs +++ b/src/Core/Billing/Services/Contracts/UpdateProviderSeatMinimumsCommand.cs @@ -1,10 +1,10 @@ -using Bit.Core.Billing.Enums; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Enums; namespace Bit.Core.Billing.Services.Contracts; -/// The ID of the provider to update the seat minimums for. +/// The provider to update the seat minimums for. /// The new seat minimums for the provider. public record UpdateProviderSeatMinimumsCommand( - Guid Id, - string GatewaySubscriptionId, + Provider Provider, IReadOnlyCollection<(PlanType Plan, int SeatsMinimum)> Configuration); diff --git a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs b/src/Core/Models/Business/ProviderSubscriptionUpdate.cs deleted file mode 100644 index 1fd833ca1f..0000000000 --- a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs +++ /dev/null @@ -1,62 +0,0 @@ -using Bit.Core.Billing; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; -using Stripe; -using Plan = Bit.Core.Models.StaticStore.Plan; - -namespace Bit.Core.Models.Business; - -public class ProviderSubscriptionUpdate : SubscriptionUpdate -{ - private readonly string _planId; - private readonly int _previouslyPurchasedSeats; - private readonly int _newlyPurchasedSeats; - - protected override List PlanIds => [_planId]; - - public ProviderSubscriptionUpdate( - Plan plan, - int previouslyPurchasedSeats, - int newlyPurchasedSeats) - { - if (!plan.Type.SupportsConsolidatedBilling()) - { - throw new BillingException( - message: $"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing"); - } - - _planId = plan.PasswordManager.StripeProviderPortalSeatPlanId; - _previouslyPurchasedSeats = previouslyPurchasedSeats; - _newlyPurchasedSeats = newlyPurchasedSeats; - } - - public override List RevertItemsOptions(Subscription subscription) - { - var subscriptionItem = FindSubscriptionItem(subscription, _planId); - - return - [ - new SubscriptionItemOptions - { - Id = subscriptionItem.Id, - Price = _planId, - Quantity = _previouslyPurchasedSeats - } - ]; - } - - public override List UpgradeItemsOptions(Subscription subscription) - { - var subscriptionItem = FindSubscriptionItem(subscription, _planId); - - return - [ - new SubscriptionItemOptions - { - Id = subscriptionItem.Id, - Price = _planId, - Quantity = _newlyPurchasedSeats - } - ]; - } -} diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index e3495c0e65..bd7efdbad4 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -1,5 +1,4 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Api.Requests.Accounts; using Bit.Core.Billing.Models.Api.Requests.Organizations; @@ -25,11 +24,6 @@ public interface IPaymentService int? newlyPurchasedAdditionalSecretsManagerServiceAccounts, int newlyPurchasedAdditionalStorage); Task AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats); - Task AdjustSeats( - Provider provider, - Plan plan, - int currentlySubscribedSeats, - int newlySubscribedSeats); Task AdjustSmSeatsAsync(Organization organization, Plan plan, int additionalSeats); Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index cdcd14ca90..d8889bca26 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1,5 +1,4 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; @@ -251,18 +250,6 @@ public class StripePaymentService : IPaymentService public Task AdjustSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats) => FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats)); - public Task AdjustSeats( - Provider provider, - StaticStore.Plan plan, - int currentlySubscribedSeats, - int newlySubscribedSeats) - => FinalizeSubscriptionChangeAsync( - provider, - new ProviderSubscriptionUpdate( - plan, - currentlySubscribedSeats, - newlySubscribedSeats)); - public Task AdjustSmSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats) => FinalizeSubscriptionChangeAsync( organization, From 0f0c3a4e5ad93708b634bac489dffff67e91d664 Mon Sep 17 00:00:00 2001 From: Vijay Oommen Date: Thu, 3 Apr 2025 08:35:29 -0500 Subject: [PATCH 11/15] [PM-19423] Update an existing org with license should set UseRiskInsights flag (#5539) --- src/Core/AdminConsole/Entities/Organization.cs | 1 + .../Models/Data/Organizations/SelfHostedOrganizationDetails.cs | 3 ++- src/Core/Billing/Licenses/LicenseConstants.cs | 1 + .../Implementations/OrganizationLicenseClaimsFactory.cs | 1 + 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Core/AdminConsole/Entities/Organization.cs b/src/Core/AdminConsole/Entities/Organization.cs index 54661e22a7..e91f1ede29 100644 --- a/src/Core/AdminConsole/Entities/Organization.cs +++ b/src/Core/AdminConsole/Entities/Organization.cs @@ -313,5 +313,6 @@ public class Organization : ITableObject, IStorableSubscriber, IRevisable, UseSecretsManager = license.UseSecretsManager; SmSeats = license.SmSeats; SmServiceAccounts = license.SmServiceAccounts; + UseRiskInsights = license.UseRiskInsights; } } diff --git a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs index c53ac8745c..ab2dfd7e0e 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs @@ -148,7 +148,8 @@ public class SelfHostedOrganizationDetails : Organization LimitCollectionDeletion = LimitCollectionDeletion, LimitItemDeletion = LimitItemDeletion, AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems, - Status = Status + Status = Status, + UseRiskInsights = UseRiskInsights, }; } } diff --git a/src/Core/Billing/Licenses/LicenseConstants.cs b/src/Core/Billing/Licenses/LicenseConstants.cs index 564019affc..50510914a5 100644 --- a/src/Core/Billing/Licenses/LicenseConstants.cs +++ b/src/Core/Billing/Licenses/LicenseConstants.cs @@ -36,6 +36,7 @@ public static class OrganizationLicenseConstants public const string SmServiceAccounts = nameof(SmServiceAccounts); public const string LimitCollectionCreationDeletion = nameof(LimitCollectionCreationDeletion); public const string AllowAdminAccessToAllCollectionItems = nameof(AllowAdminAccessToAllCollectionItems); + public const string UseRiskInsights = nameof(UseRiskInsights); public const string Expires = nameof(Expires); public const string Refresh = nameof(Refresh); public const string ExpirationWithoutGracePeriod = nameof(ExpirationWithoutGracePeriod); diff --git a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs index e436102012..62e1889564 100644 --- a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs +++ b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs @@ -47,6 +47,7 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory Date: Thu, 3 Apr 2025 10:03:31 -0500 Subject: [PATCH 12/15] Added more tests to catch more use cases and fix bugs. (#5598) --- .../v1/RestoreOrganizationUserCommand.cs | 29 ++- .../RestoreOrganizationUserCommandTests.cs | 212 +++++++++++++++++- 2 files changed, 229 insertions(+), 12 deletions(-) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs index 3d4b0fba5c..f122463a98 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs @@ -87,7 +87,10 @@ public class RestoreOrganizationUserCommand( .twoFactorIsEnabled; } - await CheckUserForOtherFreeOrganizationOwnershipAsync(organizationUser); + if (organization.PlanType == PlanType.Free) + { + await CheckUserForOtherFreeOrganizationOwnershipAsync(organizationUser); + } await CheckPoliciesBeforeRestoreAsync(organizationUser, userTwoFactorIsEnabled); @@ -100,7 +103,7 @@ public class RestoreOrganizationUserCommand( private async Task CheckUserForOtherFreeOrganizationOwnershipAsync(OrganizationUser organizationUser) { - var relatedOrgUsersFromOtherOrgs = await organizationUserRepository.GetManyByUserAsync(organizationUser.UserId.Value); + var relatedOrgUsersFromOtherOrgs = await organizationUserRepository.GetManyByUserAsync(organizationUser.UserId!.Value); var otherOrgs = await organizationRepository.GetManyByUserIdAsync(organizationUser.UserId.Value); var orgOrgUserDict = relatedOrgUsersFromOtherOrgs @@ -110,13 +113,16 @@ public class RestoreOrganizationUserCommand( CheckForOtherFreeOrganizationOwnership(organizationUser, orgOrgUserDict); } - private async Task> GetRelatedOrganizationUsersAndOrganizations( - IEnumerable organizationUsers) + private async Task> GetRelatedOrganizationUsersAndOrganizationsAsync( + List organizationUsers) { - var allUserIds = organizationUsers.Select(x => x.UserId.Value); + var allUserIds = organizationUsers + .Where(x => x.UserId.HasValue) + .Select(x => x.UserId.Value); var otherOrganizationUsers = (await organizationUserRepository.GetManyByManyUsersAsync(allUserIds)) - .Where(x => organizationUsers.Any(y => y.Id == x.Id) == false); + .Where(x => organizationUsers.Any(y => y.Id == x.Id) == false) + .ToArray(); var otherOrgs = await organizationRepository.GetManyByIdsAsync(otherOrganizationUsers .Select(x => x.OrganizationId) @@ -130,7 +136,9 @@ public class RestoreOrganizationUserCommand( Dictionary otherOrgUsersAndOrgs) { var ownerOrAdminList = new[] { OrganizationUserType.Owner, OrganizationUserType.Admin }; - if (otherOrgUsersAndOrgs.Any(x => + + if (ownerOrAdminList.Any(x => organizationUser.Type == x) && + otherOrgUsersAndOrgs.Any(x => x.Key.UserId == organizationUser.UserId && ownerOrAdminList.Any(userType => userType == x.Key.Type) && x.Key.Status == OrganizationUserStatusType.Confirmed && @@ -170,7 +178,7 @@ public class RestoreOrganizationUserCommand( var organizationUsersTwoFactorEnabled = await twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync( filteredUsers.Where(ou => ou.UserId.HasValue).Select(ou => ou.UserId.Value)); - var orgUsersAndOrgs = await GetRelatedOrganizationUsersAndOrganizations(filteredUsers); + var orgUsersAndOrgs = await GetRelatedOrganizationUsersAndOrganizationsAsync(filteredUsers); var result = new List>(); @@ -201,7 +209,10 @@ public class RestoreOrganizationUserCommand( await CheckPoliciesBeforeRestoreAsync(organizationUser, twoFactorIsEnabled); - CheckForOtherFreeOrganizationOwnership(organizationUser, orgUsersAndOrgs); + if (organization.PlanType == PlanType.Free) + { + CheckForOtherFreeOrganizationOwnership(organizationUser, orgUsersAndOrgs); + } var status = OrganizationService.GetPriorActiveOrganizationUserStatusType(organizationUser); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs index 726664849d..f91ca779a8 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs @@ -471,10 +471,11 @@ public class RestoreOrganizationUserCommandTests Organization organization, Organization otherOrganization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, - [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.Owner)] OrganizationUser organizationUser, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserOwnerFromDifferentOrg, SutProvider sutProvider) { + organization.PlanType = PlanType.Free; organizationUser.Email = null; // this is required to mock that the user as had already been confirmed before the revoke orgUserOwnerFromDifferentOrg.UserId = organizationUser.UserId; @@ -506,6 +507,107 @@ public class RestoreOrganizationUserCommandTests Assert.Equal("User is an owner/admin of another free organization. Please have them upgrade to a paid plan to restore their account.", exception.Message); } + [Theory, BitAutoData] + public async Task RestoreUser_WhenUserOwningAnotherFreeOrganizationAndIsOnlyAUserInCurrentOrg_ThenUserShouldBeRestored( + Organization organization, + Organization otherOrganization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserOwnerFromDifferentOrg, + SutProvider sutProvider) + { + organization.PlanType = PlanType.Free; + organizationUser.Email = null; // this is required to mock that the user as had already been confirmed before the revoke + + orgUserOwnerFromDifferentOrg.UserId = organizationUser.UserId; + otherOrganization.Id = orgUserOwnerFromDifferentOrg.OrganizationId; + otherOrganization.PlanType = PlanType.Free; + + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + var organizationUserRepository = sutProvider.GetDependency(); + organizationUserRepository + .GetManyByUserAsync(organizationUser.UserId.Value) + .Returns([orgUserOwnerFromDifferentOrg]); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(organizationUser.UserId.Value) + .Returns([otherOrganization]); + + sutProvider.GetDependency() + .GetPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.TwoFactorAuthentication, + Arg.Any()) + .Returns([ + new OrganizationUserPolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.TwoFactorAuthentication + } + ]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Is>(i => i.Contains(organizationUser.UserId.Value))) + .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) }); + + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + + await organizationUserRepository + .Received(1) + .RestoreAsync(organizationUser.Id, + Arg.Is(x => x != OrganizationUserStatusType.Revoked)); + } + + [Theory, BitAutoData] + public async Task RestoreUser_WhenUserOwningAnotherFreeOrganizationAndCurrentOrgIsNotFree_ThenUserShouldBeRestored( + Organization organization, + Organization otherOrganization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.Owner)] OrganizationUser organizationUser, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserOwnerFromDifferentOrg, + SutProvider sutProvider) + { + organization.PlanType = PlanType.EnterpriseAnnually2023; + + organizationUser.Email = null; // this is required to mock that the user as had already been confirmed before the revoke + + orgUserOwnerFromDifferentOrg.UserId = organizationUser.UserId; + otherOrganization.Id = orgUserOwnerFromDifferentOrg.OrganizationId; + otherOrganization.PlanType = PlanType.Free; + + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + var organizationUserRepository = sutProvider.GetDependency(); + organizationUserRepository + .GetManyByUserAsync(organizationUser.UserId.Value) + .Returns([orgUserOwnerFromDifferentOrg]); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(organizationUser.UserId.Value) + .Returns([otherOrganization]); + + sutProvider.GetDependency() + .GetPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.TwoFactorAuthentication, + Arg.Any()) + .Returns([ + new OrganizationUserPolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.TwoFactorAuthentication + } + ]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Is>(i => i.Contains(organizationUser.UserId.Value))) + .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) }); + + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + + await organizationUserRepository + .Received(1) + .RestoreAsync(organizationUser.Id, + Arg.Is(x => x != OrganizationUserStatusType.Revoked)); + } + [Theory, BitAutoData] public async Task RestoreUsers_Success(Organization organization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, @@ -612,7 +714,7 @@ public class RestoreOrganizationUserCommandTests [Theory, BitAutoData] public async Task RestoreUsers_UserOwnsAnotherFreeOrganization_BlocksOwnerUserFromBeingRestored(Organization organization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, - [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.Owner)] OrganizationUser orgUser1, [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser2, [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser3, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserFromOtherOrg, @@ -637,7 +739,7 @@ public class RestoreOrganizationUserCommandTests organizationUserRepository .GetManyAsync(Arg.Is>(ids => ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id) && ids.Contains(orgUser3.Id))) - .Returns(new[] { orgUser1, orgUser2, orgUser3 }); + .Returns([orgUser1, orgUser2, orgUser3]); userRepository.GetByIdAsync(orgUser2.UserId!.Value).Returns(new User { Email = "test@example.com" }); @@ -674,6 +776,110 @@ public class RestoreOrganizationUserCommandTests .RestoreAsync(orgUser1.Id, OrganizationUserStatusType.Confirmed); } + [Theory, BitAutoData] + public async Task RestoreUsers_UserOwnsAnotherFreeOrganizationButReactivatingOrgIsPaid_RestoresUser(Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.Owner)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserFromOtherOrg, + Organization otherOrganization, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = PlanType.EnterpriseAnnually2023; + + RestoreUser_Setup(organization, owner, orgUser1, sutProvider); + var organizationUserRepository = sutProvider.GetDependency(); + var policyService = sutProvider.GetDependency(); + var userService = Substitute.For(); + + orgUser1.OrganizationId = organization.Id; + + orgUserFromOtherOrg.UserId = orgUser1.UserId; + + otherOrganization.Id = orgUserFromOtherOrg.OrganizationId; + otherOrganization.PlanType = PlanType.Free; + + organizationUserRepository + .GetManyAsync(Arg.Is>(ids => ids.Contains(orgUser1.Id))) + .Returns([orgUser1]); + + organizationUserRepository + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([orgUserFromOtherOrg]); + + sutProvider.GetDependency() + .GetManyByIdsAsync(Arg.Is>(ids => ids.Contains(orgUserFromOtherOrg.OrganizationId))) + .Returns([otherOrganization]); + + + // Setup 2FA policy + policyService.GetPoliciesApplicableToUserAsync(Arg.Any(), PolicyType.TwoFactorAuthentication, Arg.Any()) + .Returns([new OrganizationUserPolicyDetails { OrganizationId = organization.Id, PolicyType = PolicyType.TwoFactorAuthentication }]); + + // User1 has 2FA, User2 doesn't + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Is>(ids => ids.Contains(orgUser1.UserId!.Value))) + .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> + { + (orgUser1.UserId!.Value, true) + }); + + // Act + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService); + + // Assert + Assert.Single(result); + Assert.Equal(string.Empty, result[0].Item2); + await organizationUserRepository + .Received(1) + .RestoreAsync(orgUser1.Id, Arg.Is(x => x != OrganizationUserStatusType.Revoked)); + } + + [Theory] + [BitAutoData] + public async Task RestoreUsers_UserOwnsAnotherOrganizationButIsOnlyUserOfCurrentOrganization_UserShouldBeRestored( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserFromOtherOrg, + Organization otherOrganization, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = PlanType.Free; + + RestoreUser_Setup(organization, owner, orgUser1, sutProvider); + var organizationUserRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + orgUser1.OrganizationId = organization.Id; + + orgUserFromOtherOrg.UserId = orgUser1.UserId; + + otherOrganization.Id = orgUserFromOtherOrg.OrganizationId; + otherOrganization.PlanType = PlanType.Free; + + organizationUserRepository + .GetManyAsync(Arg.Is>(ids => ids.Contains(orgUser1.Id))) + .Returns([orgUser1]); + + organizationUserRepository + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([orgUserFromOtherOrg]); + + sutProvider.GetDependency().GetPoliciesApplicableToUserAsync(Arg.Any(), PolicyType.TwoFactorAuthentication, Arg.Any()) + .Returns([new OrganizationUserPolicyDetails { OrganizationId = organization.Id, PolicyType = PolicyType.TwoFactorAuthentication }]); + + // Act + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService); + + Assert.Single(result); + Assert.Equal(string.Empty, result[0].Item2); + await organizationUserRepository + .Received(1) + .RestoreAsync(orgUser1.Id, Arg.Is(x => x != OrganizationUserStatusType.Revoked)); + } + private static void RestoreUser_Setup( Organization organization, OrganizationUser? requestingOrganizationUser, From 33f5a19b9981d2cbb189e60d992b86e78c4b6737 Mon Sep 17 00:00:00 2001 From: Brant DeBow <125889545+brant-livefront@users.noreply.github.com> Date: Thu, 3 Apr 2025 11:23:00 -0400 Subject: [PATCH 13/15] [PM-17562] Add Dapper and EF Repositories For Ogranization Integrations and Configurations (#5589) * [PM-17562] Add Dapper and EF Repositories For Ogranization Integrations and Configurations * Updated with changes from PR comments --- ...nizationIntegrationConfigurationDetails.cs | 64 +++++++++++++++ ...ationIntegrationConfigurationRepository.cs | 13 +++ .../IOrganizationIntegrationRepository.cs | 7 ++ ...ationIntegrationConfigurationRepository.cs | 43 ++++++++++ .../OrganizationIntegrationRepository.cs | 16 ++++ ...ationIntegrationConfigurationRepository.cs | 33 ++++++++ .../OrganizationIntegrationRepository.cs | 14 ++++ ...tTypeOrganizationIdIntegrationTypeQuery.cs | 39 +++++++++ ...ityFrameworkServiceCollectionExtensions.cs | 2 + .../Repositories/DatabaseContext.cs | 2 + ...ionIntegrationConfigurationDetailsTests.cs | 82 +++++++++++++++++++ 11 files changed, 315 insertions(+) create mode 100644 src/Core/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetails.cs create mode 100644 src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs create mode 100644 src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs create mode 100644 src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs create mode 100644 src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs create mode 100644 src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs create mode 100644 src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs create mode 100644 src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs create mode 100644 test/Core.Test/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetailsTests.cs diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetails.cs new file mode 100644 index 0000000000..139a7aff25 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetails.cs @@ -0,0 +1,64 @@ +using System.Text.Json.Nodes; +using Bit.Core.Enums; + +#nullable enable + +namespace Bit.Core.Models.Data.Organizations; + +public class OrganizationIntegrationConfigurationDetails +{ + public Guid Id { get; set; } + public Guid OrganizationIntegrationId { get; set; } + public IntegrationType IntegrationType { get; set; } + public EventType EventType { get; set; } + public string? Configuration { get; set; } + public string? IntegrationConfiguration { get; set; } + public string? Template { get; set; } + + public JsonObject MergedConfiguration + { + get + { + var integrationJson = IntegrationConfigurationJson; + + foreach (var kvp in ConfigurationJson) + { + integrationJson[kvp.Key] = kvp.Value?.DeepClone(); + } + + return integrationJson; + } + } + + private JsonObject ConfigurationJson + { + get + { + try + { + var configuration = Configuration ?? string.Empty; + return JsonNode.Parse(configuration) as JsonObject ?? new JsonObject(); + } + catch + { + return new JsonObject(); + } + } + } + + private JsonObject IntegrationConfigurationJson + { + get + { + try + { + var integration = IntegrationConfiguration ?? string.Empty; + return JsonNode.Parse(integration) as JsonObject ?? new JsonObject(); + } + catch + { + return new JsonObject(); + } + } + } +} diff --git a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs new file mode 100644 index 0000000000..516918fff9 --- /dev/null +++ b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs @@ -0,0 +1,13 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations; + +namespace Bit.Core.Repositories; + +public interface IOrganizationIntegrationConfigurationRepository : IRepository +{ + Task> GetConfigurationDetailsAsync( + Guid organizationId, + IntegrationType integrationType, + EventType eventType); +} diff --git a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs new file mode 100644 index 0000000000..cd7700c310 --- /dev/null +++ b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs @@ -0,0 +1,7 @@ +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.Repositories; + +public interface IOrganizationIntegrationRepository : IRepository +{ +} diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs new file mode 100644 index 0000000000..f3227dfd22 --- /dev/null +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs @@ -0,0 +1,43 @@ +using System.Data; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations; +using Bit.Core.Repositories; +using Bit.Core.Settings; +using Bit.Infrastructure.Dapper.Repositories; +using Dapper; +using Microsoft.Data.SqlClient; + +namespace Bit.Infrastructure.Dapper.AdminConsole.Repositories; + +public class OrganizationIntegrationConfigurationRepository : Repository, IOrganizationIntegrationConfigurationRepository +{ + public OrganizationIntegrationConfigurationRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public OrganizationIntegrationConfigurationRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetConfigurationDetailsAsync( + Guid organizationId, + IntegrationType integrationType, + EventType eventType) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationIntegrationConfigurationDetails_ReadManyByEventTypeOrganizationIdIntegrationType]", + new + { + EventType = eventType, + OrganizationId = organizationId, + IntegrationType = integrationType + }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } +} diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs new file mode 100644 index 0000000000..99f0e35378 --- /dev/null +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs @@ -0,0 +1,16 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Repositories; +using Bit.Core.Settings; + +namespace Bit.Infrastructure.Dapper.Repositories; + +public class OrganizationIntegrationRepository : Repository, IOrganizationIntegrationRepository +{ + public OrganizationIntegrationRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public OrganizationIntegrationRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } +} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs new file mode 100644 index 0000000000..f051830035 --- /dev/null +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs @@ -0,0 +1,33 @@ +using AutoMapper; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations; +using Bit.Core.Repositories; +using Bit.Infrastructure.EntityFramework.AdminConsole.Models; +using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.Infrastructure.EntityFramework.Repositories.Queries; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Infrastructure.EntityFramework.AdminConsole.Repositories; + +public class OrganizationIntegrationConfigurationRepository : Repository, IOrganizationIntegrationConfigurationRepository +{ + public OrganizationIntegrationConfigurationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, context => context.OrganizationIntegrationConfigurations) + { } + + public async Task> GetConfigurationDetailsAsync( + Guid organizationId, + IntegrationType integrationType, + EventType eventType) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery( + organizationId, eventType, integrationType + ); + return await query.Run(dbContext).ToListAsync(); + } + } +} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs new file mode 100644 index 0000000000..816ad3b25f --- /dev/null +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs @@ -0,0 +1,14 @@ +using AutoMapper; +using Bit.Core.Repositories; +using Bit.Infrastructure.EntityFramework.AdminConsole.Models; +using Bit.Infrastructure.EntityFramework.Repositories; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Infrastructure.EntityFramework.AdminConsole.Repositories; + +public class OrganizationIntegrationRepository : Repository, IOrganizationIntegrationRepository +{ + public OrganizationIntegrationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationIntegrations) + { } +} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs new file mode 100644 index 0000000000..1a54d6588a --- /dev/null +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs @@ -0,0 +1,39 @@ +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations; + +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery : IQuery +{ + private readonly Guid _organizationId; + private readonly EventType _eventType; + private readonly IntegrationType _integrationType; + + public OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery(Guid organizationId, EventType eventType, IntegrationType integrationType) + { + _organizationId = organizationId; + _eventType = eventType; + _integrationType = integrationType; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from oic in dbContext.OrganizationIntegrationConfigurations + join oi in dbContext.OrganizationIntegrations on oic.OrganizationIntegrationId equals oi.Id into oioic + from oi in dbContext.OrganizationIntegrations + where oi.OrganizationId == _organizationId && + oi.Type == _integrationType && + oic.EventType == _eventType + select new OrganizationIntegrationConfigurationDetails() + { + Id = oic.Id, + OrganizationIntegrationId = oic.OrganizationIntegrationId, + IntegrationType = oi.Type, + EventType = oic.EventType, + Configuration = oic.Configuration, + IntegrationConfiguration = oi.Configuration, + Template = oic.Template + }; + return query; + } +} diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index 3f805bbe2c..ad6c7cf369 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -78,6 +78,8 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index dd1b97b4f2..5c1c1bc87f 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -55,6 +55,8 @@ public class DatabaseContext : DbContext public DbSet OrganizationApiKeys { get; set; } public DbSet OrganizationSponsorships { get; set; } public DbSet OrganizationConnections { get; set; } + public DbSet OrganizationIntegrations { get; set; } + public DbSet OrganizationIntegrationConfigurations { get; set; } public DbSet OrganizationUsers { get; set; } public DbSet Policies { get; set; } public DbSet Providers { get; set; } diff --git a/test/Core.Test/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetailsTests.cs b/test/Core.Test/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetailsTests.cs new file mode 100644 index 0000000000..99a11903b4 --- /dev/null +++ b/test/Core.Test/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetailsTests.cs @@ -0,0 +1,82 @@ +using System.Text.Json; +using Bit.Core.Models.Data.Organizations; +using Xunit; + +namespace Bit.Core.Test.Models.Data.Organizations; + +public class OrganizationIntegrationConfigurationDetailsTests +{ + [Fact] + public void MergedConfiguration_WithValidConfigAndIntegration_ReturnsMergedJson() + { + var config = new { config = "A new config value" }; + var integration = new { integration = "An integration value" }; + var expectedObj = new { integration = "An integration value", config = "A new config value" }; + var expected = JsonSerializer.Serialize(expectedObj); + + var sut = new OrganizationIntegrationConfigurationDetails(); + sut.Configuration = JsonSerializer.Serialize(config); + sut.IntegrationConfiguration = JsonSerializer.Serialize(integration); + + var result = sut.MergedConfiguration; + Assert.Equal(expected, result.ToJsonString()); + } + + [Fact] + public void MergedConfiguration_WithInvalidJsonConfigAndIntegration_ReturnsEmptyJson() + { + var expectedObj = new { }; + var expected = JsonSerializer.Serialize(expectedObj); + + var sut = new OrganizationIntegrationConfigurationDetails(); + sut.Configuration = "Not JSON"; + sut.IntegrationConfiguration = "Not JSON"; + + var result = sut.MergedConfiguration; + Assert.Equal(expected, result.ToJsonString()); + } + + [Fact] + public void MergedConfiguration_WithNullConfigAndIntegration_ReturnsEmptyJson() + { + var expectedObj = new { }; + var expected = JsonSerializer.Serialize(expectedObj); + + var sut = new OrganizationIntegrationConfigurationDetails(); + sut.Configuration = null; + sut.IntegrationConfiguration = null; + + var result = sut.MergedConfiguration; + Assert.Equal(expected, result.ToJsonString()); + } + + [Fact] + public void MergedConfiguration_WithValidIntegrationAndNullConfig_ReturnsIntegrationJson() + { + var integration = new { integration = "An integration value" }; + var expectedObj = new { integration = "An integration value" }; + var expected = JsonSerializer.Serialize(expectedObj); + + var sut = new OrganizationIntegrationConfigurationDetails(); + sut.Configuration = null; + sut.IntegrationConfiguration = JsonSerializer.Serialize(integration); + + var result = sut.MergedConfiguration; + Assert.Equal(expected, result.ToJsonString()); + } + + [Fact] + public void MergedConfiguration_WithValidConfigAndNullIntegration_ReturnsConfigJson() + { + var config = new { config = "A new config value" }; + var expectedObj = new { config = "A new config value" }; + var expected = JsonSerializer.Serialize(expectedObj); + + var sut = new OrganizationIntegrationConfigurationDetails(); + sut.Configuration = JsonSerializer.Serialize(config); + sut.IntegrationConfiguration = null; + + var result = sut.MergedConfiguration; + Assert.Equal(expected, result.ToJsonString()); + } +} From 38ae5ff885b2990f241e94b611934f2470766f5c Mon Sep 17 00:00:00 2001 From: Jimmy Vo Date: Thu, 3 Apr 2025 11:35:09 -0400 Subject: [PATCH 14/15] [PM-19588] Ensure custom users cannot delete or remove admins. (#5590) --- ...teManagedOrganizationUserAccountCommand.cs | 6 ++++ .../RemoveOrganizationUserCommand.cs | 8 ++++- ...agedOrganizationUserAccountCommandTests.cs | 32 +++++++++++++++++++ .../RemoveOrganizationUserCommandTests.cs | 22 +++++++++++++ 4 files changed, 67 insertions(+), 1 deletion(-) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteManagedOrganizationUserAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteManagedOrganizationUserAccountCommand.cs index 010f5de9bf..7b7d8003a3 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteManagedOrganizationUserAccountCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteManagedOrganizationUserAccountCommand.cs @@ -154,6 +154,12 @@ public class DeleteManagedOrganizationUserAccountCommand : IDeleteManagedOrganiz } } + if (orgUser.Type == OrganizationUserType.Admin && await _currentContext.OrganizationCustom(organizationId)) + { + throw new BadRequestException("Custom users can not delete admins."); + } + + if (!managementStatus.TryGetValue(orgUser.Id, out var isManaged) || !isManaged) { throw new BadRequestException("Member is not managed by the organization."); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs index 9375a231ec..3568a2a2b9 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs @@ -25,7 +25,8 @@ public class RemoveOrganizationUserCommand : IRemoveOrganizationUserCommand public const string UserNotFoundErrorMessage = "User not found."; public const string UsersInvalidErrorMessage = "Users invalid."; public const string RemoveYourselfErrorMessage = "You cannot remove yourself."; - public const string RemoveOwnerByNonOwnerErrorMessage = "Only owners can delete other owners."; + public const string RemoveOwnerByNonOwnerErrorMessage = "Only owners can remove other owners."; + public const string RemoveAdminByCustomUserErrorMessage = "Custom users can not remove admins."; public const string RemoveLastConfirmedOwnerErrorMessage = "Organization must have at least one confirmed owner."; public const string RemoveClaimedAccountErrorMessage = "Cannot remove member accounts claimed by the organization. To offboard a member, revoke or delete the account."; @@ -153,6 +154,11 @@ public class RemoveOrganizationUserCommand : IRemoveOrganizationUserCommand } } + if (orgUser.Type == OrganizationUserType.Admin && await _currentContext.OrganizationCustom(orgUser.OrganizationId)) + { + throw new BadRequestException(RemoveAdminByCustomUserErrorMessage); + } + if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning) && deletingUserId.HasValue && eventSystemUser == null) { var managementStatus = await _getOrganizationUsersManagementStatusQuery.GetUsersOrganizationManagementStatusAsync(orgUser.OrganizationId, new[] { orgUser.Id }); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteManagedOrganizationUserAccountCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteManagedOrganizationUserAccountCommandTests.cs index b21ae5459f..f8f6bdb60d 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteManagedOrganizationUserAccountCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteManagedOrganizationUserAccountCommandTests.cs @@ -131,6 +131,38 @@ public class DeleteManagedOrganizationUserAccountCommandTests .LogOrganizationUserEventAsync(Arg.Any(), Arg.Any(), Arg.Any()); } + [Theory] + [BitAutoData] + public async Task DeleteUserAsync_WhenCustomUserDeletesAdmin_ThrowsException( + SutProvider sutProvider, User user, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Admin)] OrganizationUser organizationUser, + Guid deletingUserId) + { + // Arrange + organizationUser.UserId = user.Id; + + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + + sutProvider.GetDependency().GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .OrganizationCustom(organizationUser.OrganizationId) + .Returns(true); + + // Act + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.DeleteUserAsync(organizationUser.OrganizationId, organizationUser.Id, deletingUserId)); + + // Assert + Assert.Equal("Custom users can not delete admins.", exception.Message); + await sutProvider.GetDependency().Received(0).DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().Received(0) + .LogOrganizationUserEventAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } + [Theory] [BitAutoData] public async Task DeleteUserAsync_DeletingOwnerWhenNotOwner_ThrowsException( diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommandTests.cs index 6ab8236b8e..a60850c5a9 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommandTests.cs @@ -171,6 +171,28 @@ public class RemoveOrganizationUserCommandTests Assert.Contains(RemoveOrganizationUserCommand.RemoveOwnerByNonOwnerErrorMessage, exception.Message); } + [Theory, BitAutoData] + public async Task RemoveUser_WhenCustomUserRemovesAdmin_ThrowsException( + [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser organizationUser, + [OrganizationUser(type: OrganizationUserType.Custom)] OrganizationUser deletingUser, + SutProvider sutProvider) + { + // Arrange + organizationUser.OrganizationId = deletingUser.OrganizationId; + + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + sutProvider.GetDependency() + .OrganizationCustom(organizationUser.OrganizationId) + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveUserAsync(organizationUser.OrganizationId, organizationUser.Id, deletingUser.UserId)); + Assert.Contains(RemoveOrganizationUserCommand.RemoveAdminByCustomUserErrorMessage, exception.Message); + } + [Theory, BitAutoData] public async Task RemoveUser_WithDeletingUserId_RemovingLastOwner_ThrowsException( [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, From 83e06c924100f4384e746d9a793d1486e60f8130 Mon Sep 17 00:00:00 2001 From: Jake Fink Date: Thu, 3 Apr 2025 11:57:51 -0400 Subject: [PATCH 15/15] [PM-19523] Filter expected webauthn keys for rotations by prf enabled (#5566) * filter expected webauthn keys for rotations by prf enabled * fix and add tests * format --- .../WebAuthnLoginKeyRotationValidator.cs | 16 +++--- .../WebauthnLoginKeyRotationValidatorTests.cs | 56 +++++++++++++++++++ 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/src/Api/KeyManagement/Validators/WebAuthnLoginKeyRotationValidator.cs b/src/Api/KeyManagement/Validators/WebAuthnLoginKeyRotationValidator.cs index 1706aebd78..9c7efe0fbe 100644 --- a/src/Api/KeyManagement/Validators/WebAuthnLoginKeyRotationValidator.cs +++ b/src/Api/KeyManagement/Validators/WebAuthnLoginKeyRotationValidator.cs @@ -17,20 +17,20 @@ public class WebAuthnLoginKeyRotationValidator : IRotationValidator> ValidateAsync(User user, IEnumerable keysToRotate) { - // 2024-06: Remove after 3 releases, for backward compatibility - if (keysToRotate == null) - { - return new List(); - } - var result = new List(); var existing = await _webAuthnCredentialRepository.GetManyByUserIdAsync(user.Id); - if (existing == null || !existing.Any()) + if (existing == null) { return result; } - foreach (var ea in existing) + var validCredentials = existing.Where(credential => credential.SupportsPrf); + if (!validCredentials.Any()) + { + return result; + } + + foreach (var ea in validCredentials) { var keyToRotate = keysToRotate.FirstOrDefault(c => c.Id == ea.Id); if (keyToRotate == null) diff --git a/test/Api.Test/KeyManagement/Validators/WebauthnLoginKeyRotationValidatorTests.cs b/test/Api.Test/KeyManagement/Validators/WebauthnLoginKeyRotationValidatorTests.cs index de661497e4..93652735ef 100644 --- a/test/Api.Test/KeyManagement/Validators/WebauthnLoginKeyRotationValidatorTests.cs +++ b/test/Api.Test/KeyManagement/Validators/WebauthnLoginKeyRotationValidatorTests.cs @@ -14,6 +14,59 @@ namespace Bit.Api.Test.KeyManagement.Validators; [SutProviderCustomize] public class WebAuthnLoginKeyRotationValidatorTests { + [Theory] + [BitAutoData] + public async Task ValidateAsync_Succeeds_ReturnsValidCredentials( + SutProvider sutProvider, User user, + IEnumerable webauthnRotateCredentialData) + { + var guid = Guid.NewGuid(); + + var webauthnKeysToRotate = webauthnRotateCredentialData.Select(e => new WebAuthnLoginRotateKeyRequestModel + { + Id = guid, + EncryptedPublicKey = e.EncryptedPublicKey, + EncryptedUserKey = e.EncryptedUserKey + }).ToList(); + + var data = new WebAuthnCredential + { + Id = guid, + SupportsPrf = true, + EncryptedPublicKey = "TestKey", + EncryptedUserKey = "Test" + }; + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id) + .Returns(new List { data }); + + var result = await sutProvider.Sut.ValidateAsync(user, webauthnKeysToRotate); + Assert.Single(result); + Assert.Equal(guid, result.First().Id); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_DoesNotSupportPRF_Ignores( + SutProvider sutProvider, User user, + IEnumerable webauthnRotateCredentialData) + { + var guid = Guid.NewGuid(); + var webauthnKeysToRotate = webauthnRotateCredentialData.Select(e => new WebAuthnLoginRotateKeyRequestModel + { + Id = guid, + EncryptedUserKey = e.EncryptedUserKey, + EncryptedPublicKey = e.EncryptedPublicKey, + }).ToList(); + + var data = new WebAuthnCredential { Id = guid, EncryptedUserKey = "Test", EncryptedPublicKey = "TestKey" }; + + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id) + .Returns(new List { data }); + + var result = await sutProvider.Sut.ValidateAsync(user, webauthnKeysToRotate); + Assert.Empty(result); + } + [Theory] [BitAutoData] public async Task ValidateAsync_WrongWebAuthnKeys_Throws( @@ -30,6 +83,7 @@ public class WebAuthnLoginKeyRotationValidatorTests var data = new WebAuthnCredential { Id = Guid.Parse("00000000-0000-0000-0000-000000000002"), + SupportsPrf = true, EncryptedPublicKey = "TestKey", EncryptedUserKey = "Test" }; @@ -55,6 +109,7 @@ public class WebAuthnLoginKeyRotationValidatorTests var data = new WebAuthnCredential { Id = guid, + SupportsPrf = true, EncryptedPublicKey = "TestKey", EncryptedUserKey = "Test" }; @@ -81,6 +136,7 @@ public class WebAuthnLoginKeyRotationValidatorTests var data = new WebAuthnCredential { Id = guid, + SupportsPrf = true, EncryptedPublicKey = "TestKey", EncryptedUserKey = "Test" };