From 8ecd9c5fb329a8949043173766fad3398cb0bd44 Mon Sep 17 00:00:00 2001 From: Brandon Treston Date: Thu, 1 May 2025 10:07:19 -0400 Subject: [PATCH 01/11] [PM-19332] Create InitPendingOrganizationCommand (#5584) * wip * implement CommandResult * remove auth handler * fix import * remove method from OrganizationService * cleanup, add tests * clean up * fix auth in tests * clean up comments * clean up comments * clean up test --- .../OrganizationUsersController.cs | 7 +- .../InitPendingOrganizationCommand.cs | 128 +++++++++++++ .../IInitPendingOrganizationCommand.cs | 13 ++ .../Services/IOrganizationService.cs | 7 - .../Implementations/OrganizationService.cs | 78 +------- ...OrganizationServiceCollectionExtensions.cs | 1 + .../InitPendingOrganizationCommandTests.cs | 169 ++++++++++++++++++ 7 files changed, 317 insertions(+), 86 deletions(-) create mode 100644 src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs create mode 100644 src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IInitPendingOrganizationCommand.cs create mode 100644 test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommandTests.cs diff --git a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs index 5a714943f0..e21dd3de49 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs @@ -63,6 +63,7 @@ public class OrganizationUsersController : Controller private readonly IPricingClient _pricingClient; private readonly IConfirmOrganizationUserCommand _confirmOrganizationUserCommand; private readonly IRestoreOrganizationUserCommand _restoreOrganizationUserCommand; + private readonly IInitPendingOrganizationCommand _initPendingOrganizationCommand; public OrganizationUsersController( IOrganizationRepository organizationRepository, @@ -89,7 +90,8 @@ public class OrganizationUsersController : Controller IFeatureService featureService, IPricingClient pricingClient, IConfirmOrganizationUserCommand confirmOrganizationUserCommand, - IRestoreOrganizationUserCommand restoreOrganizationUserCommand) + IRestoreOrganizationUserCommand restoreOrganizationUserCommand, + IInitPendingOrganizationCommand initPendingOrganizationCommand) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -116,6 +118,7 @@ public class OrganizationUsersController : Controller _pricingClient = pricingClient; _confirmOrganizationUserCommand = confirmOrganizationUserCommand; _restoreOrganizationUserCommand = restoreOrganizationUserCommand; + _initPendingOrganizationCommand = initPendingOrganizationCommand; } [HttpGet("{id}")] @@ -313,7 +316,7 @@ public class OrganizationUsersController : Controller throw new UnauthorizedAccessException(); } - await _organizationService.InitPendingOrganization(user, orgId, organizationUserId, model.Keys.PublicKey, model.Keys.EncryptedPrivateKey, model.CollectionName, model.Token); + await _initPendingOrganizationCommand.InitPendingOrganizationAsync(user, orgId, organizationUserId, model.Keys.PublicKey, model.Keys.EncryptedPrivateKey, model.CollectionName, model.Token); await _acceptOrgUserCommand.AcceptOrgUserByEmailTokenAsync(organizationUserId, user, model.Token, _userService); await _confirmOrganizationUserCommand.ConfirmUserAsync(orgId, organizationUserId, model.Key, user.Id); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs new file mode 100644 index 0000000000..3e060c66a5 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs @@ -0,0 +1,128 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Models.Data; +using Bit.Core.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Tokens; +using Microsoft.AspNetCore.DataProtection; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; + +public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand +{ + + private readonly IOrganizationService _organizationService; + private readonly ICollectionRepository _collectionRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; + private readonly IDataProtector _dataProtector; + private readonly IGlobalSettings _globalSettings; + private readonly IPolicyService _policyService; + private readonly IOrganizationUserRepository _organizationUserRepository; + + public InitPendingOrganizationCommand( + IOrganizationService organizationService, + ICollectionRepository collectionRepository, + IOrganizationRepository organizationRepository, + IDataProtectorTokenFactory orgUserInviteTokenDataFactory, + IDataProtectionProvider dataProtectionProvider, + IGlobalSettings globalSettings, + IPolicyService policyService, + IOrganizationUserRepository organizationUserRepository + ) + { + _organizationService = organizationService; + _collectionRepository = collectionRepository; + _organizationRepository = organizationRepository; + _orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory; + _dataProtector = dataProtectionProvider.CreateProtector(OrgUserInviteTokenable.DataProtectorPurpose); + _globalSettings = globalSettings; + _policyService = policyService; + _organizationUserRepository = organizationUserRepository; + } + + public async Task InitPendingOrganizationAsync(User user, Guid organizationId, Guid organizationUserId, string publicKey, string privateKey, string collectionName, string emailToken) + { + await ValidateSignUpPoliciesAsync(user.Id); + + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null) + { + throw new BadRequestException("User invalid."); + } + + var tokenValid = ValidateInviteToken(orgUser, user, emailToken); + + if (!tokenValid) + { + throw new BadRequestException("Invalid token"); + } + + var org = await _organizationRepository.GetByIdAsync(organizationId); + + if (org.Enabled) + { + throw new BadRequestException("Organization is already enabled."); + } + + if (org.Status != OrganizationStatusType.Pending) + { + throw new BadRequestException("Organization is not on a Pending status."); + } + + if (!string.IsNullOrEmpty(org.PublicKey)) + { + throw new BadRequestException("Organization already has a Public Key."); + } + + if (!string.IsNullOrEmpty(org.PrivateKey)) + { + throw new BadRequestException("Organization already has a Private Key."); + } + + org.Enabled = true; + org.Status = OrganizationStatusType.Created; + org.PublicKey = publicKey; + org.PrivateKey = privateKey; + + await _organizationService.UpdateAsync(org); + + if (!string.IsNullOrWhiteSpace(collectionName)) + { + // give the owner Can Manage access over the default collection + List defaultOwnerAccess = + [new CollectionAccessSelection { Id = orgUser.Id, HidePasswords = false, ReadOnly = false, Manage = true }]; + + var defaultCollection = new Collection + { + Name = collectionName, + OrganizationId = org.Id + }; + await _collectionRepository.CreateAsync(defaultCollection, null, defaultOwnerAccess); + } + } + + private async Task ValidateSignUpPoliciesAsync(Guid ownerId) + { + var anySingleOrgPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg); + if (anySingleOrgPolicies) + { + throw new BadRequestException("You may not create an organization. You belong to an organization " + + "which has a policy that prohibits you from being a member of any other organization."); + } + } + + private bool ValidateInviteToken(OrganizationUser orgUser, User user, string emailToken) + { + var tokenValid = OrgUserInviteTokenable.ValidateOrgUserInviteStringToken( + _orgUserInviteTokenDataFactory, emailToken, orgUser); + + return tokenValid; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IInitPendingOrganizationCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IInitPendingOrganizationCommand.cs new file mode 100644 index 0000000000..273182664e --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IInitPendingOrganizationCommand.cs @@ -0,0 +1,13 @@ +using Bit.Core.Entities; +namespace Bit.Core.OrganizationFeatures.OrganizationUsers.Interfaces; + +public interface IInitPendingOrganizationCommand +{ + /// + /// Update an Organization entry by setting the public/private keys, set it as 'Enabled' and move the Status from 'Pending' to 'Created'. + /// + /// + /// This method must target a disabled Organization that has null keys and status as 'Pending'. + /// + Task InitPendingOrganizationAsync(User user, Guid organizationId, Guid organizationUserId, string publicKey, string privateKey, string collectionName, string emailToken); +} diff --git a/src/Core/AdminConsole/Services/IOrganizationService.cs b/src/Core/AdminConsole/Services/IOrganizationService.cs index 228c2b522c..9c9e311a02 100644 --- a/src/Core/AdminConsole/Services/IOrganizationService.cs +++ b/src/Core/AdminConsole/Services/IOrganizationService.cs @@ -48,13 +48,6 @@ public interface IOrganizationService Task>> RevokeUsersAsync(Guid organizationId, IEnumerable organizationUserIds, Guid? revokingUserId); Task CreatePendingOrganization(Organization organization, string ownerEmail, ClaimsPrincipal user, IUserService userService, bool salesAssistedTrialStarted); - /// - /// Update an Organization entry by setting the public/private keys, set it as 'Enabled' and move the Status from 'Pending' to 'Created'. - /// - /// - /// This method must target a disabled Organization that has null keys and status as 'Pending'. - /// - Task InitPendingOrganization(User user, Guid organizationId, Guid organizationUserId, string publicKey, string privateKey, string collectionName, string emailToken); Task ReplaceAndUpdateCacheAsync(Organization org, EventType? orgEvent = null); void ValidatePasswordManagerPlan(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade); diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index b31b43406e..532aebf5e0 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -13,7 +13,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Enums; -using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Repositories; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Billing.Constants; @@ -31,12 +30,10 @@ using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Settings; -using Bit.Core.Tokens; using Bit.Core.Tools.Enums; using Bit.Core.Tools.Models.Business; using Bit.Core.Tools.Services; using Bit.Core.Utilities; -using Microsoft.AspNetCore.DataProtection; using Microsoft.Extensions.Logging; using Stripe; using OrganizationUserInvite = Bit.Core.Models.Business.OrganizationUserInvite; @@ -77,8 +74,6 @@ public class OrganizationService : IOrganizationService private readonly IPricingClient _pricingClient; private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly ISendOrganizationInvitesCommand _sendOrganizationInvitesCommand; - private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; - private readonly IDataProtector _dataProtector; public OrganizationService( IOrganizationRepository organizationRepository, @@ -112,9 +107,7 @@ public class OrganizationService : IOrganizationService IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, IPricingClient pricingClient, IPolicyRequirementQuery policyRequirementQuery, - ISendOrganizationInvitesCommand sendOrganizationInvitesCommand, - IDataProtectorTokenFactory orgUserInviteTokenDataFactory, - IDataProtectionProvider dataProtectionProvider + ISendOrganizationInvitesCommand sendOrganizationInvitesCommand ) { _organizationRepository = organizationRepository; @@ -149,8 +142,6 @@ public class OrganizationService : IOrganizationService _pricingClient = pricingClient; _policyRequirementQuery = policyRequirementQuery; _sendOrganizationInvitesCommand = sendOrganizationInvitesCommand; - _orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory; - _dataProtector = dataProtectionProvider.CreateProtector(OrgUserInviteTokenable.DataProtectorPurpose); } public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, @@ -1921,71 +1912,4 @@ public class OrganizationService : IOrganizationService SalesAssistedTrialStarted = salesAssistedTrialStarted, }); } - - public async Task InitPendingOrganization(User user, Guid organizationId, Guid organizationUserId, string publicKey, string privateKey, string collectionName, string emailToken) - { - await ValidateSignUpPoliciesAsync(user.Id); - - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null) - { - throw new BadRequestException("User invalid."); - } - - // TODO: PM-4142 - remove old token validation logic once 3 releases of backwards compatibility are complete - var newTokenValid = OrgUserInviteTokenable.ValidateOrgUserInviteStringToken( - _orgUserInviteTokenDataFactory, emailToken, orgUser); - - var tokenValid = newTokenValid || - CoreHelpers.UserInviteTokenIsValid(_dataProtector, emailToken, user.Email, orgUser.Id, - _globalSettings); - - if (!tokenValid) - { - throw new BadRequestException("Invalid token."); - } - - var org = await GetOrgById(organizationId); - - if (org.Enabled) - { - throw new BadRequestException("Organization is already enabled."); - } - - if (org.Status != OrganizationStatusType.Pending) - { - throw new BadRequestException("Organization is not on a Pending status."); - } - - if (!string.IsNullOrEmpty(org.PublicKey)) - { - throw new BadRequestException("Organization already has a Public Key."); - } - - if (!string.IsNullOrEmpty(org.PrivateKey)) - { - throw new BadRequestException("Organization already has a Private Key."); - } - - org.Enabled = true; - org.Status = OrganizationStatusType.Created; - org.PublicKey = publicKey; - org.PrivateKey = privateKey; - - await UpdateAsync(org); - - if (!string.IsNullOrWhiteSpace(collectionName)) - { - // give the owner Can Manage access over the default collection - List defaultOwnerAccess = - [new CollectionAccessSelection { Id = organizationUserId, HidePasswords = false, ReadOnly = false, Manage = true }]; - - var defaultCollection = new Collection - { - Name = collectionName, - OrganizationId = org.Id - }; - await _collectionRepository.CreateAsync(defaultCollection, null, defaultOwnerAccess); - } - } } diff --git a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs index 164710d522..b016e329bf 100644 --- a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs +++ b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs @@ -193,6 +193,7 @@ public static class OrganizationServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); } // TODO: move to OrganizationSubscriptionServiceCollectionExtensions when OrganizationUser methods are moved out of diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommandTests.cs new file mode 100644 index 0000000000..83ea4798db --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommandTests.cs @@ -0,0 +1,169 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; +using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.Models.Data; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Tokens; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Bit.Test.Common.Fakes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Organizations; + +[SutProviderCustomize] +public class InitPendingOrganizationCommandTests +{ + + private readonly IOrgUserInviteTokenableFactory _orgUserInviteTokenableFactory = Substitute.For(); + private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory = new FakeDataProtectorTokenFactory(); + + [Theory, BitAutoData] + public async Task Init_Organization_Success(User user, Guid orgId, Guid orgUserId, string publicKey, + string privateKey, SutProvider sutProvider, Organization org, OrganizationUser orgUser) + { + var token = CreateToken(orgUser, orgUserId, sutProvider); + + org.PrivateKey = null; + org.PublicKey = null; + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(orgId).Returns(org); + + var organizationServcie = sutProvider.GetDependency(); + var collectionRepository = sutProvider.GetDependency(); + + await sutProvider.Sut.InitPendingOrganizationAsync(user, orgId, orgUserId, publicKey, privateKey, "", token); + + await organizationRepository.Received().GetByIdAsync(orgId); + await organizationServcie.Received().UpdateAsync(org); + await collectionRepository.DidNotReceiveWithAnyArgs().CreateAsync(default); + + } + + [Theory, BitAutoData] + public async Task Init_Organization_With_CollectionName_Success(User user, Guid orgId, Guid orgUserId, string publicKey, + string privateKey, SutProvider sutProvider, Organization org, string collectionName, OrganizationUser orgUser) + { + var token = CreateToken(orgUser, orgUserId, sutProvider); + + org.PrivateKey = null; + org.PublicKey = null; + org.Id = orgId; + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(orgId).Returns(org); + + var organizationServcie = sutProvider.GetDependency(); + var collectionRepository = sutProvider.GetDependency(); + + await sutProvider.Sut.InitPendingOrganizationAsync(user, orgId, orgUserId, publicKey, privateKey, collectionName, token); + + await organizationRepository.Received().GetByIdAsync(orgId); + await organizationServcie.Received().UpdateAsync(org); + + await collectionRepository.Received().CreateAsync( + Arg.Any(), + Arg.Is>(l => l == null), + Arg.Is>(l => l.Any(i => i.Manage == true))); + + } + + [Theory, BitAutoData] + public async Task Init_Organization_When_Organization_Is_Enabled(User user, Guid orgId, Guid orgUserId, string publicKey, + string privateKey, SutProvider sutProvider, Organization org, OrganizationUser orgUser) + + { + var token = CreateToken(orgUser, orgUserId, sutProvider); + + org.Enabled = true; + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(orgId).Returns(org); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.InitPendingOrganizationAsync(user, orgId, orgUserId, publicKey, privateKey, "", token)); + + Assert.Equal("Organization is already enabled.", exception.Message); + + } + + [Theory, BitAutoData] + public async Task Init_Organization_When_Organization_Is_Not_Pending(User user, Guid orgId, Guid orgUserId, string publicKey, + string privateKey, SutProvider sutProvider, Organization org, OrganizationUser orgUser) + + { + + + var token = CreateToken(orgUser, orgUserId, sutProvider); + + org.Status = Enums.OrganizationStatusType.Created; + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(orgId).Returns(org); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.InitPendingOrganizationAsync(user, orgId, orgUserId, publicKey, privateKey, "", token)); + + Assert.Equal("Organization is not on a Pending status.", exception.Message); + + } + + [Theory, BitAutoData] + public async Task Init_Organization_When_Organization_Has_Public_Key(User user, Guid orgId, Guid orgUserId, string publicKey, + string privateKey, SutProvider sutProvider, Organization org, OrganizationUser orgUser) + + { + var token = CreateToken(orgUser, orgUserId, sutProvider); + + org.PublicKey = publicKey; + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(orgId).Returns(org); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.InitPendingOrganizationAsync(user, orgId, orgUserId, publicKey, privateKey, "", token)); + + Assert.Equal("Organization already has a Public Key.", exception.Message); + + } + + [Theory, BitAutoData] + public async Task Init_Organization_When_Organization_Has_Private_Key(User user, Guid orgId, Guid orgUserId, string publicKey, + string privateKey, SutProvider sutProvider, Organization org, OrganizationUser orgUser) + + { + + var token = CreateToken(orgUser, orgUserId, sutProvider); + + org.PublicKey = null; + org.PrivateKey = privateKey; + org.Enabled = false; + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(orgId).Returns(org); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.InitPendingOrganizationAsync(user, orgId, orgUserId, publicKey, privateKey, "", token)); + + Assert.Equal("Organization already has a Private Key.", exception.Message); + + } + + public string CreateToken(OrganizationUser orgUser, Guid orgUserId, SutProvider sutProvider) + { + sutProvider.SetDependency(_orgUserInviteTokenDataFactory, "orgUserInviteTokenDataFactory"); + sutProvider.Create(); + + _orgUserInviteTokenableFactory.CreateToken(orgUser).Returns(new OrgUserInviteTokenable(orgUser) + { + ExpirationDate = DateTime.UtcNow.Add(TimeSpan.FromDays(5)) + }); + + var orgUserInviteTokenable = _orgUserInviteTokenableFactory.CreateToken(orgUser); + var protectedToken = _orgUserInviteTokenDataFactory.Protect(orgUserInviteTokenable); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(orgUser); + + return protectedToken; + } +} From dc5db5673f4ca5d70d3bac0ca0b5900f699a808f Mon Sep 17 00:00:00 2001 From: cyprain-okeke <108260115+cyprain-okeke@users.noreply.github.com> Date: Thu, 1 May 2025 16:35:51 +0100 Subject: [PATCH 02/11] [PM-17775] (#5699) * Changes to allow admin to send F4E sponsorship * Fix the failing unit tests * Fix the failing test Signed-off-by: Cy Okeke * Merge Changes with pm-17777 Signed-off-by: Cy Okeke * Add changes for autoscale Signed-off-by: Cy Okeke * Return the right error response Signed-off-by: Cy Okeke * Resolve the failing unit test Signed-off-by: Cy Okeke --------- Signed-off-by: Cy Okeke --- .../OrganizationSponsorshipsController.cs | 29 ++++++- ...nizationSponsorshipInvitesResponseModel.cs | 37 +++++++++ .../CreateSponsorshipCommand.cs | 25 ++++-- ...OrganizationSponsorshipsControllerTests.cs | 78 +++++++++++++++++++ 4 files changed, 161 insertions(+), 8 deletions(-) create mode 100644 src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipInvitesResponseModel.cs diff --git a/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs b/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs index 67cd691a34..9a328081fe 100644 --- a/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs +++ b/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs @@ -1,4 +1,5 @@ using Bit.Api.Models.Request.Organizations; +using Bit.Api.Models.Response; using Bit.Api.Models.Response.Organizations; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationConnections.Interfaces; @@ -8,6 +9,7 @@ using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.Models.Api.Request.OrganizationSponsorships; using Bit.Core.Models.Api.Response.OrganizationSponsorships; +using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -105,7 +107,10 @@ public class OrganizationSponsorshipsController : Controller model.FriendlyName, model.IsAdminInitiated.GetValueOrDefault(), model.Notes); - await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); + if (sponsorship.OfferedToEmail != null) + { + await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); + } } [Authorize("Application")] @@ -246,5 +251,27 @@ public class OrganizationSponsorshipsController : Controller return new OrganizationSponsorshipSyncStatusResponseModel(lastSyncDate); } + [Authorize("Application")] + [HttpGet("{sponsoringOrgId}/sponsored")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task> GetSponsoredOrganizations(Guid sponsoringOrgId) + { + var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); + if (sponsoringOrg == null) + { + throw new NotFoundException(); + } + var organization = _currentContext.Organizations.First(x => x.Id == sponsoringOrg.Id); + if (!await _currentContext.OrganizationOwner(sponsoringOrg.Id) && !await _currentContext.OrganizationAdmin(sponsoringOrg.Id) && !organization.Permissions.ManageUsers) + { + throw new UnauthorizedAccessException(); + } + + var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrgId); + return new ListResponseModel(sponsorships.Select(s => + new OrganizationSponsorshipInvitesResponseModel(new OrganizationSponsorshipData(s)))); + + } + private Task CurrentUser => _userService.GetUserByIdAsync(_currentContext.UserId.Value); } diff --git a/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipInvitesResponseModel.cs b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipInvitesResponseModel.cs new file mode 100644 index 0000000000..b75144c81b --- /dev/null +++ b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipInvitesResponseModel.cs @@ -0,0 +1,37 @@ +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; + +namespace Bit.Core.Models.Api.Response.OrganizationSponsorships; + +public class OrganizationSponsorshipInvitesResponseModel : ResponseModel +{ + public OrganizationSponsorshipInvitesResponseModel(OrganizationSponsorshipData sponsorshipData, string obj = "organizationSponsorship") : base(obj) + { + if (sponsorshipData == null) + { + throw new ArgumentNullException(nameof(sponsorshipData)); + } + + SponsoringOrganizationUserId = sponsorshipData.SponsoringOrganizationUserId; + FriendlyName = sponsorshipData.FriendlyName; + OfferedToEmail = sponsorshipData.OfferedToEmail; + PlanSponsorshipType = sponsorshipData.PlanSponsorshipType; + LastSyncDate = sponsorshipData.LastSyncDate; + ValidUntil = sponsorshipData.ValidUntil; + ToDelete = sponsorshipData.ToDelete; + IsAdminInitiated = sponsorshipData.IsAdminInitiated; + Notes = sponsorshipData.Notes; + CloudSponsorshipRemoved = sponsorshipData.CloudSponsorshipRemoved; + } + + public Guid SponsoringOrganizationUserId { get; set; } + public string FriendlyName { get; set; } + public string OfferedToEmail { get; set; } + public PlanSponsorshipType PlanSponsorshipType { get; set; } + public DateTime? LastSyncDate { get; set; } + public DateTime? ValidUntil { get; set; } + public bool ToDelete { get; set; } + public bool IsAdminInitiated { get; set; } + public string Notes { get; set; } + public bool CloudSponsorshipRemoved { get; set; } +} diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs index f81a1d9e84..3b74baf6f9 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs @@ -47,11 +47,12 @@ public class CreateSponsorshipCommand( throw new BadRequestException("Only confirmed users can sponsor other organizations."); } - var existingOrgSponsorship = await organizationSponsorshipRepository - .GetBySponsoringOrganizationUserIdAsync(sponsoringMember.Id); - if (existingOrgSponsorship?.SponsoredOrganizationId != null) + var sponsorships = + await organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id); + var existingSponsorship = sponsorships.FirstOrDefault(s => s.FriendlyName == friendlyName); + if (existingSponsorship != null) { - throw new BadRequestException("Can only sponsor one organization per Organization User."); + return existingSponsorship; } if (isAdminInitiated) @@ -70,10 +71,20 @@ public class CreateSponsorshipCommand( Notes = notes }; - if (existingOrgSponsorship != null) + if (!isAdminInitiated) { - // Replace existing invalid offer with our new sponsorship offer - sponsorship.Id = existingOrgSponsorship.Id; + var existingOrgSponsorship = await organizationSponsorshipRepository + .GetBySponsoringOrganizationUserIdAsync(sponsoringMember.Id); + if (existingOrgSponsorship?.SponsoredOrganizationId != null) + { + throw new BadRequestException("Can only sponsor one organization per Organization User."); + } + + if (existingOrgSponsorship != null) + { + // Replace existing invalid offer with our new sponsorship offer + sponsorship.Id = existingOrgSponsorship.Id; + } } if (isAdminInitiated && sponsoringOrganization.Seats.HasValue) diff --git a/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs index 377bc9c2c8..f6158b9e3f 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs @@ -6,6 +6,7 @@ using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Data; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -13,6 +14,7 @@ using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; +using NSubstitute.ReturnsExtensions; using Xunit; namespace Bit.Api.Test.Billing.Controllers; @@ -146,4 +148,80 @@ public class OrganizationSponsorshipsControllerTests .DidNotReceiveWithAnyArgs() .RemoveSponsorshipAsync(default); } + + [Theory] + [BitAutoData] + public async Task GetSponsoredOrganizations_OrganizationNotFound_ThrowsNotFound( + Guid sponsoringOrgId, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(sponsoringOrgId).ReturnsNull(); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetSponsoredOrganizations(sponsoringOrgId)); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetManyBySponsoringOrganizationAsync(default); + } + + [Theory] + [BitAutoData] + public async Task GetSponsoredOrganizations_NotOrganizationOwner_ThrowsNotFound( + Organization sponsoringOrg, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency().OrganizationOwner(sponsoringOrg.Id).Returns(false); + sutProvider.GetDependency().OrganizationAdmin(sponsoringOrg.Id).Returns(false); + + // Create a CurrentContextOrganization with ManageUsers set to false + var currentContextOrg = new CurrentContextOrganization + { + Id = sponsoringOrg.Id, + Permissions = new Permissions { ManageUsers = false } + }; + sutProvider.GetDependency().Organizations.Returns(new List { currentContextOrg }); + + // Act & Assert + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetSponsoredOrganizations(sponsoringOrg.Id)); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetManyBySponsoringOrganizationAsync(default); + } + + [Theory] + [BitAutoData] + public async Task GetSponsoredOrganizations_Success_ReturnsSponsorships( + Organization sponsoringOrg, + List sponsorships, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency().OrganizationOwner(sponsoringOrg.Id).Returns(true); + sutProvider.GetDependency().OrganizationAdmin(sponsoringOrg.Id).Returns(false); + + // Create a CurrentContextOrganization from the sponsoringOrg + var currentContextOrg = new CurrentContextOrganization + { + Id = sponsoringOrg.Id, + Permissions = new Permissions { ManageUsers = true } + }; + sutProvider.GetDependency().Organizations.Returns(new List { currentContextOrg }); + + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(sponsoringOrg.Id).Returns(sponsorships); + + // Act + var result = await sutProvider.Sut.GetSponsoredOrganizations(sponsoringOrg.Id); + + // Assert + Assert.Equal(sponsorships.Count, result.Data.Count()); + await sutProvider.GetDependency().Received(1) + .GetManyBySponsoringOrganizationAsync(sponsoringOrg.Id); + } } From e77acbc5ad419b3641eab6ba381f4b3d55105e10 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Thu, 1 May 2025 12:12:45 -0400 Subject: [PATCH 03/11] [PM-19956] [PM-18795] Require provider payment method during setup behind FF (#5752) * Require provider payment method during setup behind FF * Fix failing test * Run dotnet format * Rui's feedback --- .../AdminConsole/Services/ProviderService.cs | 18 +- .../Billing/ProviderBillingService.cs | 131 +++- .../Services/ProviderServiceTests.cs | 85 ++- .../Billing/ProviderBillingServiceTests.cs | 605 +++++++++++++++++- .../Controllers/ProvidersController.cs | 26 +- .../Providers/ProviderSetupRequestModel.cs | 3 + .../AdminConsole/Services/IProviderService.cs | 4 +- .../NoopProviderService.cs | 3 +- .../Services/IProviderBillingService.cs | 4 +- src/Core/Constants.cs | 1 + 10 files changed, 848 insertions(+), 32 deletions(-) diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index fff6b5271d..2fc44937a7 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -8,6 +8,7 @@ using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Context; @@ -82,7 +83,7 @@ public class ProviderService : IProviderService _pricingClient = pricingClient; } - public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null) + public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) { var owner = await _userService.GetUserByIdAsync(ownerUserId); if (owner == null) @@ -111,7 +112,20 @@ public class ProviderService : IProviderService { throw new BadRequestException("Both address and postal code are required to set up your provider."); } - var customer = await _providerBillingService.SetupCustomer(provider, taxInfo); + + var requireProviderPaymentMethodDuringSetup = + _featureService.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup); + + if (requireProviderPaymentMethodDuringSetup && tokenizedPaymentSource is not + { + Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal, + Token: not null and not "" + }) + { + throw new BadRequestException("A payment method is required to set up your provider."); + } + + var customer = await _providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); provider.GatewayCustomerId = customer.Id; var subscription = await _providerBillingService.SetupSubscription(provider); provider.GatewaySubscriptionId = subscription.Id; diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index bdfff079cf..f049d6c8df 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -6,9 +6,11 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing; +using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Repositories; @@ -21,14 +23,20 @@ using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; +using Braintree; using CsvHelper; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; +using static Bit.Core.Billing.Utilities; +using Customer = Stripe.Customer; +using Subscription = Stripe.Subscription; + namespace Bit.Commercial.Core.Billing; public class ProviderBillingService( + IBraintreeGateway braintreeGateway, IEventService eventService, IFeatureService featureService, IGlobalSettings globalSettings, @@ -39,6 +47,7 @@ public class ProviderBillingService( IProviderOrganizationRepository providerOrganizationRepository, IProviderPlanRepository providerPlanRepository, IProviderUserRepository providerUserRepository, + ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, ITaxService taxService, @@ -463,7 +472,8 @@ public class ProviderBillingService( public async Task SetupCustomer( Provider provider, - TaxInfo taxInfo) + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource = null) { if (taxInfo is not { @@ -532,13 +542,97 @@ public class ProviderBillingService( options.Coupon = provider.DiscountId; } + var requireProviderPaymentMethodDuringSetup = + featureService.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup); + + var braintreeCustomerId = ""; + + if (requireProviderPaymentMethodDuringSetup) + { + if (tokenizedPaymentSource is not + { + Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal, + Token: not null and not "" + }) + { + logger.LogError("Cannot create customer for provider ({ProviderID}) without a payment method", provider.Id); + throw new BillingException(); + } + + var (type, token) = tokenizedPaymentSource; + + // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault + switch (type) + { + case PaymentMethodType.BankAccount: + { + var setupIntent = + (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = token })) + .FirstOrDefault(); + + if (setupIntent == null) + { + logger.LogError("Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account", provider.Id); + throw new BillingException(); + } + + await setupIntentCache.Set(provider.Id, setupIntent.Id); + break; + } + case PaymentMethodType.Card: + { + options.PaymentMethod = token; + options.InvoiceSettings.DefaultPaymentMethod = token; + break; + } + case PaymentMethodType.PayPal: + { + braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, token); + options.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; + break; + } + } + } + try { return await stripeAdapter.CustomerCreateAsync(options); } - catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.TaxIdInvalid) + catch (StripeException stripeException) when (stripeException.StripeError?.Code == + StripeConstants.ErrorCodes.TaxIdInvalid) { - throw new BadRequestException("Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid."); + await Revert(); + throw new BadRequestException( + "Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid."); + } + catch + { + await Revert(); + throw; + } + + async Task Revert() + { + if (requireProviderPaymentMethodDuringSetup && tokenizedPaymentSource != null) + { + // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault + switch (tokenizedPaymentSource.Type) + { + case PaymentMethodType.BankAccount: + { + var setupIntentId = await setupIntentCache.Get(provider.Id); + await stripeAdapter.SetupIntentCancel(setupIntentId, + new SetupIntentCancelOptions { CancellationReason = "abandoned" }); + await setupIntentCache.Remove(provider.Id); + break; + } + case PaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): + { + await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); + break; + } + } + } } } @@ -580,18 +674,38 @@ public class ProviderBillingService( }); } + var requireProviderPaymentMethodDuringSetup = + featureService.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup); + + var setupIntentId = await setupIntentCache.Get(provider.Id); + + var setupIntent = !string.IsNullOrEmpty(setupIntentId) + ? await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + { + Expand = ["payment_method"] + }) + : null; + + var usePaymentMethod = + requireProviderPaymentMethodDuringSetup && + (!string.IsNullOrEmpty(customer.InvoiceSettings.DefaultPaymentMethodId) || + customer.Metadata.ContainsKey(BraintreeCustomerIdKey) || + setupIntent.IsUnverifiedBankAccount()); + var subscriptionCreateOptions = new SubscriptionCreateOptions { - CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, + CollectionMethod = usePaymentMethod ? + StripeConstants.CollectionMethod.ChargeAutomatically : StripeConstants.CollectionMethod.SendInvoice, Customer = customer.Id, - DaysUntilDue = 30, + DaysUntilDue = usePaymentMethod ? null : 30, Items = subscriptionItemOptionsList, Metadata = new Dictionary { { "providerId", provider.Id.ToString() } }, OffSession = true, - ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations + ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, + TrialPeriodDays = usePaymentMethod ? 14 : null }; if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) @@ -607,7 +721,10 @@ public class ProviderBillingService( { var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); - if (subscription.Status == StripeConstants.SubscriptionStatus.Active) + if (subscription is + { + Status: StripeConstants.SubscriptionStatus.Active or StripeConstants.SubscriptionStatus.Trialing + }) { return subscription; } diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index d2d82f47de..c66acfa8ce 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -1,5 +1,6 @@ using Bit.Commercial.Core.AdminConsole.Services; using Bit.Commercial.Core.Test.AdminConsole.AutoFixture; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -7,6 +8,7 @@ using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Context; @@ -38,7 +40,7 @@ public class ProviderServiceTests public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) { var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default)); + () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null)); Assert.Contains("Invalid owner.", exception.Message); } @@ -50,12 +52,85 @@ public class ProviderServiceTests userService.GetUserByIdAsync(user.Id).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default)); + () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null)); Assert.Contains("Invalid token.", exception.Message); } [Theory, BitAutoData] - public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TaxInfo taxInfo, + public async Task CompleteSetupAsync_InvalidTaxInfo_ThrowsBadRequestException( + User user, + Provider provider, + string key, + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource, + [ProviderUser] ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.ProviderId = provider.Id; + providerUser.UserId = user.Id; + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + + sutProvider.Create(); + + var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + taxInfo.BillingAddressCountry = null; + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource)); + + Assert.Equal("Both address and postal code are required to set up your provider.", exception.Message); + } + + [Theory, BitAutoData] + public async Task CompleteSetupAsync_InvalidTokenizedPaymentSource_ThrowsBadRequestException( + User user, + Provider provider, + string key, + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource, + [ProviderUser] ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.ProviderId = provider.Id; + providerUser.UserId = user.Id; + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + + sutProvider.Create(); + + var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay }; + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource)); + + Assert.Equal("A payment method is required to set up your provider.", exception.Message); + } + + [Theory, BitAutoData] + public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource, [ProviderUser] ProviderUser providerUser, SutProvider sutProvider) { @@ -75,7 +150,7 @@ public class ProviderServiceTests var providerBillingService = sutProvider.GetDependency(); var customer = new Customer { Id = "customer_id" }; - providerBillingService.SetupCustomer(provider, taxInfo).Returns(customer); + providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource).Returns(customer); var subscription = new Subscription { Id = "subscription_id" }; providerBillingService.SetupSubscription(provider).Returns(subscription); @@ -84,7 +159,7 @@ public class ProviderServiceTests var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo); + await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource); await sutProvider.GetDependency().Received().UpsertAsync(Arg.Is( p => diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index 2661a0eff6..1862692087 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -2,14 +2,17 @@ using System.Net; using Bit.Commercial.Core.Billing; using Bit.Commercial.Core.Billing.Models; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; @@ -24,11 +27,17 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; +using Braintree; using CsvHelper; using NSubstitute; +using NSubstitute.ExceptionExtensions; using Stripe; using Xunit; using static Bit.Core.Test.Billing.Utilities; +using Address = Stripe.Address; +using Customer = Stripe.Customer; +using PaymentMethod = Stripe.PaymentMethod; +using Subscription = Stripe.Subscription; namespace Bit.Commercial.Core.Test.Billing; @@ -833,7 +842,7 @@ public class ProviderBillingServiceTests } [Theory, BitAutoData] - public async Task SetupCustomer_Success( + public async Task SetupCustomer_NoPaymentMethod_Success( SutProvider sutProvider, Provider provider, TaxInfo taxInfo) @@ -877,6 +886,301 @@ public class ProviderBillingServiceTests Assert.Equivalent(expected, actual); } + [Theory, BitAutoData] + public async Task SetupCustomer_InvalidRequiredPaymentMethod_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay }; + + await ThrowsBillingExceptionAsync(() => + sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithBankAccount_Error_Reverts( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + stripeAdapter.SetupIntentList(Arg.Is(options => + options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ + new SetupIntent { Id = "setup_intent_id" } + ]); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Throws(); + + sutProvider.GetDependency().Get(provider.Id).Returns("setup_intent_id"); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + + await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_id"); + + await stripeAdapter.Received(1).SetupIntentCancel("setup_intent_id", Arg.Is(options => + options.CancellationReason == "abandoned")); + + await sutProvider.GetDependency().Received(1).Remove(provider.Id); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithPayPal_Error_Reverts( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token) + .Returns("braintree_customer_id"); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.Metadata["btCustomerId"] == "braintree_customer_id" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Throws(); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + + await sutProvider.GetDependency().Customer.Received(1).DeleteAsync("braintree_customer_id"); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithBankAccount_Success( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var expected = new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + stripeAdapter.SetupIntentList(Arg.Is(options => + options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ + new SetupIntent { Id = "setup_intent_id" } + ]); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Returns(expected); + + var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + + Assert.Equivalent(expected, actual); + + await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_id"); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithPayPal_Success( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var expected = new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token) + .Returns("braintree_customer_id"); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.Metadata["btCustomerId"] == "braintree_customer_id" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Returns(expected); + + var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + + Assert.Equivalent(expected, actual); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithCard_Success( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var expected = new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.PaymentMethod == tokenizedPaymentSource.Token && + o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Returns(expected); + + var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + + Assert.Equivalent(expected, actual); + } + [Theory, BitAutoData] public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid( SutProvider sutProvider, @@ -1044,7 +1348,7 @@ public class ProviderBillingServiceTests } [Theory, BitAutoData] - public async Task SetupSubscription_Succeeds( + public async Task SetupSubscription_SendInvoice_Succeeds( SutProvider sutProvider, Provider provider) { @@ -1127,6 +1431,303 @@ public class ProviderBillingServiceTests Assert.Equivalent(expected, actual); } + [Theory, BitAutoData] + public async Task SetupSubscription_ChargeAutomatically_HasCard_Succeeds( + SutProvider sutProvider, + Provider provider) + { + provider.Type = ProviderType.Msp; + provider.GatewaySubscriptionId = null; + + var customer = new Customer + { + Id = "customer_id", + InvoiceSettings = new CustomerInvoiceSettings + { + DefaultPaymentMethodId = "pm_123" + }, + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }, + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + } + }; + + foreach (var plan in providerPlans) + { + sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) + .Returns(StaticStore.GetPlan(plan.PlanType)); + } + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == "customer_id") + , Arg.Is(p => p == customer))) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && + sub.Customer == "customer_id" && + sub.DaysUntilDue == null && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + sub.TrialPeriodDays == 14)).Returns(expected); + + var actual = await sutProvider.Sut.SetupSubscription(provider); + + Assert.Equivalent(expected, actual); + } + + [Theory, BitAutoData] + public async Task SetupSubscription_ChargeAutomatically_HasBankAccount_Succeeds( + SutProvider sutProvider, + Provider provider) + { + provider.Type = ProviderType.Msp; + provider.GatewaySubscriptionId = null; + + var customer = new Customer + { + Id = "customer_id", + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary(), + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }, + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + } + }; + + foreach (var plan in providerPlans) + { + sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) + .Returns(StaticStore.GetPlan(plan.PlanType)); + } + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == "customer_id") + , Arg.Is(p => p == customer))) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + const string setupIntentId = "seti_123"; + + sutProvider.GetDependency().Get(provider.Id).Returns(setupIntentId); + + sutProvider.GetDependency().SetupIntentGet(setupIntentId, Arg.Is(options => + options.Expand.Contains("payment_method"))).Returns(new SetupIntent + { + Id = setupIntentId, + Status = "requires_action", + NextAction = new SetupIntentNextAction + { + VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() + }, + PaymentMethod = new PaymentMethod + { + UsBankAccount = new PaymentMethodUsBankAccount() + } + }); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && + sub.Customer == "customer_id" && + sub.DaysUntilDue == null && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + sub.TrialPeriodDays == 14)).Returns(expected); + + var actual = await sutProvider.Sut.SetupSubscription(provider); + + Assert.Equivalent(expected, actual); + } + + [Theory, BitAutoData] + public async Task SetupSubscription_ChargeAutomatically_HasPayPal_Succeeds( + SutProvider sutProvider, + Provider provider) + { + provider.Type = ProviderType.Msp; + provider.GatewaySubscriptionId = null; + + var customer = new Customer + { + Id = "customer_id", + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary + { + ["btCustomerId"] = "braintree_customer_id" + }, + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }, + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + } + }; + + foreach (var plan in providerPlans) + { + sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) + .Returns(StaticStore.GetPlan(plan.PlanType)); + } + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == "customer_id") + , Arg.Is(p => p == customer))) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && + sub.Customer == "customer_id" && + sub.DaysUntilDue == null && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + sub.TrialPeriodDays == 14)).Returns(expected); + + var actual = await sutProvider.Sut.SetupSubscription(provider); + + Assert.Equivalent(expected, actual); + } + #endregion #region UpdateSeatMinimums diff --git a/src/Api/AdminConsole/Controllers/ProvidersController.cs b/src/Api/AdminConsole/Controllers/ProvidersController.cs index be119744b3..b6933da0c9 100644 --- a/src/Api/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Api/AdminConsole/Controllers/ProvidersController.cs @@ -84,22 +84,22 @@ public class ProvidersController : Controller var userId = _userService.GetProperUserId(User).Value; - var taxInfo = model.TaxInfo != null - ? new TaxInfo - { - BillingAddressCountry = model.TaxInfo.Country, - BillingAddressPostalCode = model.TaxInfo.PostalCode, - TaxIdNumber = model.TaxInfo.TaxId, - BillingAddressLine1 = model.TaxInfo.Line1, - BillingAddressLine2 = model.TaxInfo.Line2, - BillingAddressCity = model.TaxInfo.City, - BillingAddressState = model.TaxInfo.State - } - : null; + var taxInfo = new TaxInfo + { + BillingAddressCountry = model.TaxInfo.Country, + BillingAddressPostalCode = model.TaxInfo.PostalCode, + TaxIdNumber = model.TaxInfo.TaxId, + BillingAddressLine1 = model.TaxInfo.Line1, + BillingAddressLine2 = model.TaxInfo.Line2, + BillingAddressCity = model.TaxInfo.City, + BillingAddressState = model.TaxInfo.State + }; + + var tokenizedPaymentSource = model.PaymentSource?.ToDomain(); var response = await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key, - taxInfo); + taxInfo, tokenizedPaymentSource); return new ProviderResponseModel(response); } diff --git a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs index 5e10807c69..697077c9b6 100644 --- a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs @@ -1,5 +1,6 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; +using Bit.Api.Billing.Models.Requests; using Bit.Api.Models.Request; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Utilities; @@ -23,7 +24,9 @@ public class ProviderSetupRequestModel public string Token { get; set; } [Required] public string Key { get; set; } + [Required] public ExpandedTaxInfoUpdateRequestModel TaxInfo { get; set; } + public TokenizedPaymentSourceRequestBody PaymentSource { get; set; } public virtual Provider ToProvider(Provider provider) { diff --git a/src/Core/AdminConsole/Services/IProviderService.cs b/src/Core/AdminConsole/Services/IProviderService.cs index 8999b3cb81..e4b6f3aabd 100644 --- a/src/Core/AdminConsole/Services/IProviderService.cs +++ b/src/Core/AdminConsole/Services/IProviderService.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; +using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -7,7 +8,8 @@ namespace Bit.Core.AdminConsole.Services; public interface IProviderService { - Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null); + Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource = null); Task UpdateAsync(Provider provider, bool updateBilling = false); Task> InviteUserAsync(ProviderUserInvite invite); diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs index bd3a757663..94c1096b58 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; +using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -7,7 +8,7 @@ namespace Bit.Core.AdminConsole.Services.NoopImplementations; public class NoopProviderService : IProviderService { - public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null) => throw new NotImplementedException(); + public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) => throw new NotImplementedException(); public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); diff --git a/src/Core/Billing/Services/IProviderBillingService.cs b/src/Core/Billing/Services/IProviderBillingService.cs index 64585f3361..6ed8910dd8 100644 --- a/src/Core/Billing/Services/IProviderBillingService.cs +++ b/src/Core/Billing/Services/IProviderBillingService.cs @@ -79,10 +79,12 @@ public interface IProviderBillingService /// /// The to create a Stripe customer for. /// The to use for calculating the customer's automatic tax. + /// The (ex. Credit Card) to attach to the customer. /// The newly created for the . Task SetupCustomer( Provider provider, - TaxInfo taxInfo); + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource = null); /// /// For use during the provider setup process, this method starts a Stripe for the given . diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index a05b89a94f..0d4c105b2d 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -149,6 +149,7 @@ public static class FeatureFlagKeys public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion"; public const string PM199566_UpdateMSPToChargeAutomatically = "pm-199566-update-msp-to-charge-automatically"; + public const string PM19956_RequireProviderPaymentMethodDuringSetup = "pm-19956-require-provider-payment-method-during-setup"; /* Data Insights and Reporting Team */ public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application"; From 706d7a5768be5ba40dec5c9bc64f2030a4f0a20e Mon Sep 17 00:00:00 2001 From: Matt Bishop Date: Thu, 1 May 2025 10:08:39 -0700 Subject: [PATCH 04/11] Migrate to new LD Action for code references (#5759) --- .github/workflows/code-references.yml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/code-references.yml b/.github/workflows/code-references.yml index ce8cb8e467..676a747017 100644 --- a/.github/workflows/code-references.yml +++ b/.github/workflows/code-references.yml @@ -37,12 +37,10 @@ jobs: - name: Collect id: collect - uses: launchdarkly/find-code-references-in-pull-request@30f4c4ab2949bbf258b797ced2fbf6dea34df9ce # v2.1.0 + uses: launchdarkly/find-code-references@e3e9da201b87ada54eb4c550c14fb783385c5c8a # v2.13.0 with: - project-key: default - environment-key: dev - access-token: ${{ secrets.LD_ACCESS_TOKEN }} - repo-token: ${{ secrets.GITHUB_TOKEN }} + accessToken: ${{ secrets.LD_ACCESS_TOKEN }} + projKey: default - name: Add label if: steps.collect.outputs.any-changed == 'true' From 0fa6962d1784a39734bc41264761eaf75708ae94 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Thu, 1 May 2025 13:39:04 -0400 Subject: [PATCH 05/11] Register EF OrganizationInstallationRepository (#5751) --- .../EntityFrameworkServiceCollectionExtensions.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index ad6c7cf369..c9f0406a58 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -103,6 +103,7 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); if (selfHosted) { From 011298c9ff0a873ba696e73ad5bfe0cbc6d32c65 Mon Sep 17 00:00:00 2001 From: Jonas Hendrickx Date: Thu, 1 May 2025 19:53:03 +0200 Subject: [PATCH 06/11] PM-16517: Create personal use plan for additional storage (#5205) * PM-16517: Create personal use plan for additional storage * f * f * f * fix * f --------- Co-authored-by: Jonas Hendrickx Co-authored-by: Conner Turnbull <133619638+cturnbull-bitwarden@users.noreply.github.com> --- src/Core/Billing/Constants/StripeConstants.cs | 4 ++++ .../Billing/Models/StaticStore/Plans/Families2019Plan.cs | 2 +- src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs | 2 +- .../Services/Implementations/PremiumUserBillingService.cs | 2 +- src/Core/Services/Implementations/UserService.cs | 6 +++--- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index 8a4303e378..b5c2794d22 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -2,6 +2,10 @@ public static class StripeConstants { + public static class Prices + { + public const string StoragePlanPersonal = "personal-storage-gb-annually"; + } public static class AutomaticTaxStatus { public const string Failed = "failed"; diff --git a/src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs b/src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs index b0ca8feeb0..93ab2c39a1 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs +++ b/src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs @@ -38,7 +38,7 @@ public record Families2019Plan : Plan HasPremiumAccessOption = true; StripePlanId = "personal-org-annually"; - StripeStoragePlanId = "storage-gb-annually"; + StripeStoragePlanId = "personal-storage-gb-annually"; StripePremiumAccessPlanId = "personal-org-premium-access-annually"; BasePrice = 12; AdditionalStoragePricePerGb = 4; diff --git a/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs b/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs index e2f51ec913..8c71e50fa4 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs +++ b/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs @@ -37,7 +37,7 @@ public record FamiliesPlan : Plan HasAdditionalStorageOption = true; StripePlanId = "2020-families-org-annually"; - StripeStoragePlanId = "storage-gb-annually"; + StripeStoragePlanId = "personal-storage-gb-annually"; BasePrice = 40; AdditionalStoragePricePerGb = 4; diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 6746a8cc98..cbd4dbbdff 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -313,7 +313,7 @@ public class PremiumUserBillingService( { subscriptionItemOptionsList.Add(new SubscriptionItemOptions { - Price = "storage-gb-annually", + Price = StripeConstants.Prices.StoragePlanPersonal, Quantity = storage }); } diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index de0fa427ba..95ee4544fa 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -11,6 +11,7 @@ using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Services; @@ -45,7 +46,6 @@ namespace Bit.Core.Services; public class UserService : UserManager, IUserService, IDisposable { private const string PremiumPlanId = "premium-annually"; - private const string StoragePlanId = "storage-gb-annually"; private readonly IUserRepository _userRepository; private readonly ICipherRepository _cipherRepository; @@ -1106,12 +1106,12 @@ public class UserService : UserManager, IUserService, IDisposable } var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, - StoragePlanId); + StripeConstants.Prices.StoragePlanPersonal); await _referenceEventService.RaiseEventAsync( new ReferenceEvent(ReferenceEventType.AdjustStorage, user, _currentContext) { Storage = storageAdjustmentGb, - PlanName = StoragePlanId, + PlanName = StripeConstants.Prices.StoragePlanPersonal, }); await SaveUserAsync(user); return secret; From 9da98d8e974b9e57468e1514b5c7820b22c755ed Mon Sep 17 00:00:00 2001 From: Matt Bishop Date: Thu, 1 May 2025 12:25:52 -0700 Subject: [PATCH 07/11] Run LD reference check on all pushes (#5760) * Run LD reference check on all pushes * Fix syntax of code-references.yml --------- Co-authored-by: Matt Andreko --- .github/workflows/code-references.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-references.yml b/.github/workflows/code-references.yml index 676a747017..30fbff32ed 100644 --- a/.github/workflows/code-references.yml +++ b/.github/workflows/code-references.yml @@ -1,7 +1,10 @@ name: Collect code references -on: - pull_request: +on: + push: +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: check-ld-secret: From 41001fefaeacaaf9a63740711f36fee27445fa0d Mon Sep 17 00:00:00 2001 From: Thomas Rittson <31796059+eliykat@users.noreply.github.com> Date: Fri, 2 May 2025 07:00:48 +1000 Subject: [PATCH 08/11] Support use of organizationId parameter in authorization (#5758) --- .../Authorization/HttpContextExtensions.cs | 20 ++++++--- .../HttpContextExtensionsTests.cs | 42 ++++++++++++++++++- 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/src/Api/AdminConsole/Authorization/HttpContextExtensions.cs b/src/Api/AdminConsole/Authorization/HttpContextExtensions.cs index ba00ea6c18..accb9539fa 100644 --- a/src/Api/AdminConsole/Authorization/HttpContextExtensions.cs +++ b/src/Api/AdminConsole/Authorization/HttpContextExtensions.cs @@ -9,7 +9,7 @@ namespace Bit.Api.AdminConsole.Authorization; public static class HttpContextExtensions { public const string NoOrgIdError = - "A route decorated with with '[Authorize]' must include a route value named 'orgId' either through the [Controller] attribute or through a '[Http*]' attribute."; + "A route decorated with with '[Authorize]' must include a route value named 'orgId' or 'organizationId' either through the [Controller] attribute or through a '[Http*]' attribute."; /// /// Returns the result of the callback, caching it in HttpContext.Features for the lifetime of the request. @@ -61,19 +61,27 @@ public static class HttpContextExtensions /// - /// Parses the {orgId} route parameter into a Guid, or throws if the {orgId} is not present or not a valid guid. + /// Parses the {orgId} or {organizationId} route parameter into a Guid, or throws if neither are present or are not valid guids. /// /// /// /// public static Guid GetOrganizationId(this HttpContext httpContext) { - httpContext.GetRouteData().Values.TryGetValue("orgId", out var orgIdParam); - if (orgIdParam == null || !Guid.TryParse(orgIdParam.ToString(), out var orgId)) + var routeValues = httpContext.GetRouteData().Values; + + routeValues.TryGetValue("orgId", out var orgIdParam); + if (orgIdParam != null && Guid.TryParse(orgIdParam.ToString(), out var orgId)) { - throw new InvalidOperationException(NoOrgIdError); + return orgId; } - return orgId; + routeValues.TryGetValue("organizationId", out var organizationIdParam); + if (organizationIdParam != null && Guid.TryParse(organizationIdParam.ToString(), out var organizationId)) + { + return organizationId; + } + + throw new InvalidOperationException(NoOrgIdError); } } diff --git a/test/Api.Test/AdminConsole/Authorization/HttpContextExtensionsTests.cs b/test/Api.Test/AdminConsole/Authorization/HttpContextExtensionsTests.cs index 1901742777..428726aaac 100644 --- a/test/Api.Test/AdminConsole/Authorization/HttpContextExtensionsTests.cs +++ b/test/Api.Test/AdminConsole/Authorization/HttpContextExtensionsTests.cs @@ -1,5 +1,7 @@ -using Bit.Api.AdminConsole.Authorization; +using AutoFixture.Xunit2; +using Bit.Api.AdminConsole.Authorization; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; using NSubstitute; using Xunit; @@ -25,4 +27,42 @@ public class HttpContextExtensionsTests await callback.ReceivedWithAnyArgs(1).Invoke(); } + [Theory] + [InlineAutoData("orgId")] + [InlineAutoData("organizationId")] + public void GetOrganizationId_GivenValidParameter_ReturnsOrganizationId(string paramName, Guid orgId) + { + var httpContext = new DefaultHttpContext + { + Request = { RouteValues = new RouteValueDictionary + { + { "userId", "someGuid" }, + { paramName, orgId.ToString() } + } + } + }; + + var result = httpContext.GetOrganizationId(); + Assert.Equal(orgId, result); + } + + [Theory] + [InlineAutoData("orgId")] + [InlineAutoData("organizationId")] + [InlineAutoData("missingParameter")] + public void GetOrganizationId_GivenMissingOrInvalidGuid_Throws(string paramName) + { + var httpContext = new DefaultHttpContext + { + Request = { RouteValues = new RouteValueDictionary + { + { "userId", "someGuid" }, + { paramName, "invalidGuid" } + } + } + }; + + var exception = Assert.Throws(() => httpContext.GetOrganizationId()); + Assert.Equal(HttpContextExtensions.NoOrgIdError, exception.Message); + } } From 2d4ec530c5c3638cbc3c7ddb286ba3442cd03014 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Thu, 1 May 2025 17:13:10 -0400 Subject: [PATCH 09/11] [PM-18955] Implement `OrganizationWarningsQuery` (#5713) * Add GetWarnings endpoint to OrganizationBillingController * Add OrganizationWarningsQueryTests --- .../OrganizationBillingController.cs | 26 ++ .../OrganizationWarningsResponse.cs | 43 +++ .../OrganizationWarningsQuery.cs | 214 ++++++++++++ src/Api/Billing/Registrations.cs | 11 + src/Api/Startup.cs | 3 + src/Core/Billing/Constants/StripeConstants.cs | 2 + src/Core/Constants.cs | 1 + .../OrganizationWarningsQueryTests.cs | 315 ++++++++++++++++++ 8 files changed, 615 insertions(+) create mode 100644 src/Api/Billing/Models/Responses/Organizations/OrganizationWarningsResponse.cs create mode 100644 src/Api/Billing/Queries/Organizations/OrganizationWarningsQuery.cs create mode 100644 src/Api/Billing/Registrations.cs create mode 100644 test/Api.Test/Billing/Queries/Organizations/OrganizationWarningsQueryTests.cs diff --git a/src/Api/Billing/Controllers/OrganizationBillingController.cs b/src/Api/Billing/Controllers/OrganizationBillingController.cs index 1d4ebc1511..2f0a4ef48b 100644 --- a/src/Api/Billing/Controllers/OrganizationBillingController.cs +++ b/src/Api/Billing/Controllers/OrganizationBillingController.cs @@ -2,6 +2,7 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; +using Bit.Api.Billing.Queries.Organizations; using Bit.Core; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; @@ -24,6 +25,7 @@ public class OrganizationBillingController( IFeatureService featureService, IOrganizationBillingService organizationBillingService, IOrganizationRepository organizationRepository, + IOrganizationWarningsQuery organizationWarningsQuery, IPaymentService paymentService, IPricingClient pricingClient, ISubscriberService subscriberService, @@ -335,4 +337,28 @@ public class OrganizationBillingController( return TypedResults.Ok(providerId); } + + [HttpGet("warnings")] + public async Task GetWarningsAsync([FromRoute] Guid organizationId) + { + /* + * We'll keep these available at the User level, because we're hiding any pertinent information and + * we want to throw as few errors as possible since these are not core features. + */ + if (!await currentContext.OrganizationUser(organizationId)) + { + return Error.Unauthorized(); + } + + var organization = await organizationRepository.GetByIdAsync(organizationId); + + if (organization == null) + { + return Error.NotFound(); + } + + var response = await organizationWarningsQuery.Run(organization); + + return TypedResults.Ok(response); + } } diff --git a/src/Api/Billing/Models/Responses/Organizations/OrganizationWarningsResponse.cs b/src/Api/Billing/Models/Responses/Organizations/OrganizationWarningsResponse.cs new file mode 100644 index 0000000000..e124bdc318 --- /dev/null +++ b/src/Api/Billing/Models/Responses/Organizations/OrganizationWarningsResponse.cs @@ -0,0 +1,43 @@ +#nullable enable +namespace Bit.Api.Billing.Models.Responses.Organizations; + +public record OrganizationWarningsResponse +{ + public FreeTrialWarning? FreeTrial { get; set; } + public InactiveSubscriptionWarning? InactiveSubscription { get; set; } + public ResellerRenewalWarning? ResellerRenewal { get; set; } + + public record FreeTrialWarning + { + public int RemainingTrialDays { get; set; } + } + + public record InactiveSubscriptionWarning + { + public required string Resolution { get; set; } + } + + public record ResellerRenewalWarning + { + public required string Type { get; set; } + public UpcomingRenewal? Upcoming { get; set; } + public IssuedRenewal? Issued { get; set; } + public PastDueRenewal? PastDue { get; set; } + + public record UpcomingRenewal + { + public required DateTime RenewalDate { get; set; } + } + + public record IssuedRenewal + { + public required DateTime IssuedDate { get; set; } + public required DateTime DueDate { get; set; } + } + + public record PastDueRenewal + { + public required DateTime SuspensionDate { get; set; } + } + } +} diff --git a/src/Api/Billing/Queries/Organizations/OrganizationWarningsQuery.cs b/src/Api/Billing/Queries/Organizations/OrganizationWarningsQuery.cs new file mode 100644 index 0000000000..f6a0e5b1e6 --- /dev/null +++ b/src/Api/Billing/Queries/Organizations/OrganizationWarningsQuery.cs @@ -0,0 +1,214 @@ +// ReSharper disable InconsistentNaming + +#nullable enable + +using Bit.Api.Billing.Models.Responses.Organizations; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services; +using Bit.Core.Context; +using Bit.Core.Services; +using Stripe; +using FreeTrialWarning = Bit.Api.Billing.Models.Responses.Organizations.OrganizationWarningsResponse.FreeTrialWarning; +using InactiveSubscriptionWarning = + Bit.Api.Billing.Models.Responses.Organizations.OrganizationWarningsResponse.InactiveSubscriptionWarning; +using ResellerRenewalWarning = + Bit.Api.Billing.Models.Responses.Organizations.OrganizationWarningsResponse.ResellerRenewalWarning; + +namespace Bit.Api.Billing.Queries.Organizations; + +public interface IOrganizationWarningsQuery +{ + Task Run( + Organization organization); +} + +public class OrganizationWarningsQuery( + ICurrentContext currentContext, + IProviderRepository providerRepository, + IStripeAdapter stripeAdapter, + ISubscriberService subscriberService) : IOrganizationWarningsQuery +{ + public async Task Run( + Organization organization) + { + var response = new OrganizationWarningsResponse(); + + var subscription = + await subscriberService.GetSubscription(organization, + new SubscriptionGetOptions { Expand = ["customer", "latest_invoice", "test_clock"] }); + + if (subscription == null) + { + return response; + } + + response.FreeTrial = await GetFreeTrialWarning(organization, subscription); + + var provider = await providerRepository.GetByOrganizationIdAsync(organization.Id); + + response.InactiveSubscription = await GetInactiveSubscriptionWarning(organization, provider, subscription); + + response.ResellerRenewal = await GetResellerRenewalWarning(provider, subscription); + + return response; + } + + private async Task GetFreeTrialWarning( + Organization organization, + Subscription subscription) + { + if (!await currentContext.EditSubscription(organization.Id)) + { + return null; + } + + if (subscription is not + { + Status: StripeConstants.SubscriptionStatus.Trialing, + TrialEnd: not null, + Customer: not null + }) + { + return null; + } + + var customer = subscription.Customer; + + var hasPaymentMethod = + !string.IsNullOrEmpty(customer.InvoiceSettings.DefaultPaymentMethodId) || + !string.IsNullOrEmpty(customer.DefaultSourceId) || + customer.Metadata.ContainsKey(StripeConstants.MetadataKeys.BraintreeCustomerId); + + if (hasPaymentMethod) + { + return null; + } + + var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow; + + var remainingTrialDays = (int)Math.Ceiling((subscription.TrialEnd.Value - now).TotalDays); + + return new FreeTrialWarning { RemainingTrialDays = remainingTrialDays }; + } + + private async Task GetInactiveSubscriptionWarning( + Organization organization, + Provider? provider, + Subscription subscription) + { + if (organization.Enabled || + subscription.Status is not StripeConstants.SubscriptionStatus.Unpaid + and not StripeConstants.SubscriptionStatus.Canceled) + { + return null; + } + + if (provider != null) + { + return new InactiveSubscriptionWarning { Resolution = "contact_provider" }; + } + + if (await currentContext.OrganizationOwner(organization.Id)) + { + return subscription.Status switch + { + StripeConstants.SubscriptionStatus.Unpaid => new InactiveSubscriptionWarning + { + Resolution = "add_payment_method" + }, + StripeConstants.SubscriptionStatus.Canceled => new InactiveSubscriptionWarning + { + Resolution = "resubscribe" + }, + _ => null + }; + } + + return new InactiveSubscriptionWarning { Resolution = "contact_owner" }; + } + + private async Task GetResellerRenewalWarning( + Provider? provider, + Subscription subscription) + { + if (provider is not + { + Type: ProviderType.Reseller + }) + { + return null; + } + + if (subscription.CollectionMethod != StripeConstants.CollectionMethod.SendInvoice) + { + return null; + } + + var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow; + + // ReSharper disable once ConvertIfStatementToSwitchStatement + if (subscription is + { + Status: StripeConstants.SubscriptionStatus.Trialing or StripeConstants.SubscriptionStatus.Active, + LatestInvoice: null or { Status: StripeConstants.InvoiceStatus.Paid } + } && (subscription.CurrentPeriodEnd - now).TotalDays <= 14) + { + return new ResellerRenewalWarning + { + Type = "upcoming", + Upcoming = new ResellerRenewalWarning.UpcomingRenewal + { + RenewalDate = subscription.CurrentPeriodEnd + } + }; + } + + if (subscription is + { + Status: StripeConstants.SubscriptionStatus.Active, + LatestInvoice: { Status: StripeConstants.InvoiceStatus.Open, DueDate: not null } + } && subscription.LatestInvoice.DueDate > now) + { + return new ResellerRenewalWarning + { + Type = "issued", + Issued = new ResellerRenewalWarning.IssuedRenewal + { + IssuedDate = subscription.LatestInvoice.Created, + DueDate = subscription.LatestInvoice.DueDate.Value + } + }; + } + + // ReSharper disable once InvertIf + if (subscription.Status == StripeConstants.SubscriptionStatus.PastDue) + { + var openInvoices = await stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + { + Query = $"subscription:'{subscription.Id}' status:'open'" + }); + + var earliestOverdueInvoice = openInvoices + .Where(invoice => invoice.DueDate != null && invoice.DueDate < now) + .MinBy(invoice => invoice.Created); + + if (earliestOverdueInvoice != null) + { + return new ResellerRenewalWarning + { + Type = "past_due", + PastDue = new ResellerRenewalWarning.PastDueRenewal + { + SuspensionDate = earliestOverdueInvoice.DueDate!.Value.AddDays(30) + } + }; + } + } + + return null; + } +} diff --git a/src/Api/Billing/Registrations.cs b/src/Api/Billing/Registrations.cs new file mode 100644 index 0000000000..cb92098333 --- /dev/null +++ b/src/Api/Billing/Registrations.cs @@ -0,0 +1,11 @@ +using Bit.Api.Billing.Queries.Organizations; + +namespace Bit.Api.Billing; + +public static class Registrations +{ + public static void AddBillingQueries(this IServiceCollection services) + { + services.AddTransient(); + } +} diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index 40448f722d..1cc371ae1b 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -27,6 +27,7 @@ using Bit.Core.OrganizationFeatures.OrganizationSubscriptions; using Bit.Core.Tools.Entities; using Bit.Core.Vault.Entities; using Bit.Api.Auth.Models.Request.WebAuthn; +using Bit.Api.Billing; using Bit.Core.AdminConsole.Services.NoopImplementations; using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Identity.TokenProviders; @@ -184,6 +185,8 @@ public class Startup services.AddImportServices(); services.AddPhishingDomainServices(globalSettings); + services.AddBillingQueries(); + // Authorization Handlers services.AddAuthorizationHandlers(); diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index b5c2794d22..c3e3ec6c30 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -46,10 +46,12 @@ public static class StripeConstants { public const string Draft = "draft"; public const string Open = "open"; + public const string Paid = "paid"; } public static class MetadataKeys { + public const string BraintreeCustomerId = "btCustomerId"; public const string InvoiceApproved = "invoice_approved"; public const string OrganizationId = "organizationId"; public const string ProviderId = "providerId"; diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 0d4c105b2d..13d0bad495 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -150,6 +150,7 @@ public static class FeatureFlagKeys public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion"; public const string PM199566_UpdateMSPToChargeAutomatically = "pm-199566-update-msp-to-charge-automatically"; public const string PM19956_RequireProviderPaymentMethodDuringSetup = "pm-19956-require-provider-payment-method-during-setup"; + public const string UseOrganizationWarningsService = "use-organization-warnings-service"; /* Data Insights and Reporting Team */ public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application"; diff --git a/test/Api.Test/Billing/Queries/Organizations/OrganizationWarningsQueryTests.cs b/test/Api.Test/Billing/Queries/Organizations/OrganizationWarningsQueryTests.cs new file mode 100644 index 0000000000..67979f506e --- /dev/null +++ b/test/Api.Test/Billing/Queries/Organizations/OrganizationWarningsQueryTests.cs @@ -0,0 +1,315 @@ +using Bit.Api.Billing.Queries.Organizations; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services; +using Bit.Core.Context; +using Bit.Core.Services; +using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Stripe.TestHelpers; +using Xunit; + +namespace Bit.Api.Test.Billing.Queries.Organizations; + +[SutProviderCustomize] +public class OrganizationWarningsQueryTests +{ + private static readonly string[] _requiredExpansions = ["customer", "latest_invoice", "test_clock"]; + + [Theory, BitAutoData] + public async Task Run_NoSubscription_NoWarnings( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .ReturnsNull(); + + var response = await sutProvider.Sut.Run(organization); + + Assert.True(response is + { + FreeTrial: null, + InactiveSubscription: null, + ResellerRenewal: null + }); + } + + [Theory, BitAutoData] + public async Task Run_Has_FreeTrialWarning( + Organization organization, + SutProvider sutProvider) + { + var now = DateTime.UtcNow; + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .Returns(new Subscription + { + Status = StripeConstants.SubscriptionStatus.Trialing, + TrialEnd = now.AddDays(7), + Customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }, + TestClock = new TestClock + { + FrozenTime = now + } + }); + + sutProvider.GetDependency().EditSubscription(organization.Id).Returns(true); + + var response = await sutProvider.Sut.Run(organization); + + Assert.True(response is + { + FreeTrial.RemainingTrialDays: 7 + }); + } + + [Theory, BitAutoData] + public async Task Run_Has_InactiveSubscriptionWarning_ContactProvider( + Organization organization, + SutProvider sutProvider) + { + organization.Enabled = false; + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .Returns(new Subscription + { + Status = StripeConstants.SubscriptionStatus.Unpaid + }); + + sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id) + .Returns(new Provider()); + + var response = await sutProvider.Sut.Run(organization); + + Assert.True(response is + { + InactiveSubscription.Resolution: "contact_provider" + }); + } + + [Theory, BitAutoData] + public async Task Run_Has_InactiveSubscriptionWarning_AddPaymentMethod( + Organization organization, + SutProvider sutProvider) + { + organization.Enabled = false; + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .Returns(new Subscription + { + Status = StripeConstants.SubscriptionStatus.Unpaid + }); + + sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); + + var response = await sutProvider.Sut.Run(organization); + + Assert.True(response is + { + InactiveSubscription.Resolution: "add_payment_method" + }); + } + + [Theory, BitAutoData] + public async Task Run_Has_InactiveSubscriptionWarning_Resubscribe( + Organization organization, + SutProvider sutProvider) + { + organization.Enabled = false; + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .Returns(new Subscription + { + Status = StripeConstants.SubscriptionStatus.Canceled + }); + + sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); + + var response = await sutProvider.Sut.Run(organization); + + Assert.True(response is + { + InactiveSubscription.Resolution: "resubscribe" + }); + } + + [Theory, BitAutoData] + public async Task Run_Has_InactiveSubscriptionWarning_ContactOwner( + Organization organization, + SutProvider sutProvider) + { + organization.Enabled = false; + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .Returns(new Subscription + { + Status = StripeConstants.SubscriptionStatus.Unpaid + }); + + sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(false); + + var response = await sutProvider.Sut.Run(organization); + + Assert.True(response is + { + InactiveSubscription.Resolution: "contact_owner" + }); + } + + [Theory, BitAutoData] + public async Task Run_Has_ResellerRenewalWarning_Upcoming( + Organization organization, + SutProvider sutProvider) + { + var now = DateTime.UtcNow; + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .Returns(new Subscription + { + CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, + Status = StripeConstants.SubscriptionStatus.Active, + CurrentPeriodEnd = now.AddDays(10), + TestClock = new TestClock + { + FrozenTime = now + } + }); + + sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id) + .Returns(new Provider + { + Type = ProviderType.Reseller + }); + + var response = await sutProvider.Sut.Run(organization); + + Assert.True(response is + { + ResellerRenewal.Type: "upcoming" + }); + + Assert.Equal(now.AddDays(10), response.ResellerRenewal.Upcoming!.RenewalDate); + } + + [Theory, BitAutoData] + public async Task Run_Has_ResellerRenewalWarning_Issued( + Organization organization, + SutProvider sutProvider) + { + var now = DateTime.UtcNow; + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .Returns(new Subscription + { + CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, + Status = StripeConstants.SubscriptionStatus.Active, + LatestInvoice = new Invoice + { + Status = StripeConstants.InvoiceStatus.Open, + DueDate = now.AddDays(30), + Created = now + }, + TestClock = new TestClock + { + FrozenTime = now + } + }); + + sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id) + .Returns(new Provider + { + Type = ProviderType.Reseller + }); + + var response = await sutProvider.Sut.Run(organization); + + Assert.True(response is + { + ResellerRenewal.Type: "issued" + }); + + Assert.Equal(now, response.ResellerRenewal.Issued!.IssuedDate); + Assert.Equal(now.AddDays(30), response.ResellerRenewal.Issued!.DueDate); + } + + [Theory, BitAutoData] + public async Task Run_Has_ResellerRenewalWarning_PastDue( + Organization organization, + SutProvider sutProvider) + { + var now = DateTime.UtcNow; + + const string subscriptionId = "subscription_id"; + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.SequenceEqual(_requiredExpansions) + )) + .Returns(new Subscription + { + Id = subscriptionId, + CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, + Status = StripeConstants.SubscriptionStatus.PastDue, + TestClock = new TestClock + { + FrozenTime = now + } + }); + + sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id) + .Returns(new Provider + { + Type = ProviderType.Reseller + }); + + var dueDate = now.AddDays(-10); + + sutProvider.GetDependency().InvoiceSearchAsync(Arg.Is(options => + options.Query == $"subscription:'{subscriptionId}' status:'open'")).Returns([ + new Invoice { DueDate = dueDate, Created = dueDate.AddDays(-30) } + ]); + + var response = await sutProvider.Sut.Run(organization); + + Assert.True(response is + { + ResellerRenewal.Type: "past_due" + }); + + Assert.Equal(dueDate.AddDays(30), response.ResellerRenewal.PastDue!.SuspensionDate); + } +} From cd3f16948b31367cfbc02f2a9e457cb45581d6cc Mon Sep 17 00:00:00 2001 From: Conner Turnbull <133619638+cturnbull-bitwarden@users.noreply.github.com> Date: Fri, 2 May 2025 08:25:52 -0400 Subject: [PATCH 10/11] Resolved the ambiguous build error (#5762) --- .../PhishingDomainFeatures/AzurePhishingDomainStorageService.cs | 2 +- .../PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs b/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs index 9af9c94e1d..0d287a2229 100644 --- a/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs +++ b/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs @@ -39,7 +39,7 @@ public class AzurePhishingDomainStorageService var content = await streamReader.ReadToEndAsync(); return [.. content - .Split(['\r', '\n'], StringSplitOptions.RemoveEmptyEntries) + .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries) .Select(line => line.Trim()) .Where(line => !string.IsNullOrWhiteSpace(line) && !line.StartsWith('#'))]; } diff --git a/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs b/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs index b059eac0e8..420948e310 100644 --- a/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs +++ b/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs @@ -92,7 +92,7 @@ public class CloudPhishingDomainDirectQuery : ICloudPhishingDomainQuery } return content - .Split(['\r', '\n'], StringSplitOptions.RemoveEmptyEntries) + .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries) .Select(line => line.Trim()) .Where(line => !string.IsNullOrWhiteSpace(line) && !line.StartsWith("#")) .ToList(); From 077d0fa6d7268a402dcbe88e574d2c4242cd6c78 Mon Sep 17 00:00:00 2001 From: Conner Turnbull <133619638+cturnbull-bitwarden@users.noreply.github.com> Date: Fri, 2 May 2025 12:53:06 -0400 Subject: [PATCH 11/11] Resolved an issue where autoscaling always happened (#5765) --- .../Services/IOrganizationService.cs | 1 + .../Implementations/OrganizationService.cs | 2 +- .../CreateSponsorshipCommand.cs | 19 ++- .../CreateSponsorshipCommandTests.cs | 127 ++++++++++++++++++ 4 files changed, 145 insertions(+), 4 deletions(-) diff --git a/src/Core/AdminConsole/Services/IOrganizationService.cs b/src/Core/AdminConsole/Services/IOrganizationService.cs index 9c9e311a02..1e53be734e 100644 --- a/src/Core/AdminConsole/Services/IOrganizationService.cs +++ b/src/Core/AdminConsole/Services/IOrganizationService.cs @@ -49,6 +49,7 @@ public interface IOrganizationService IEnumerable organizationUserIds, Guid? revokingUserId); Task CreatePendingOrganization(Organization organization, string ownerEmail, ClaimsPrincipal user, IUserService userService, bool salesAssistedTrialStarted); Task ReplaceAndUpdateCacheAsync(Organization org, EventType? orgEvent = null); + Task<(bool canScale, string failureReason)> CanScaleAsync(Organization organization, int seatsToAdd); void ValidatePasswordManagerPlan(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade); void ValidateSecretsManagerPlan(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade); diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 532aebf5e0..5c7e5e29ed 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -1058,7 +1058,7 @@ public class OrganizationService : IOrganizationService organization: organization, initOrganization: initOrganization)); - internal async Task<(bool canScale, string failureReason)> CanScaleAsync( + public async Task<(bool canScale, string failureReason)> CanScaleAsync( Organization organization, int seatsToAdd) { diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs index 3b74baf6f9..b15cbea240 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs @@ -15,7 +15,8 @@ public class CreateSponsorshipCommand( ICurrentContext currentContext, IOrganizationSponsorshipRepository organizationSponsorshipRepository, IUserService userService, - IOrganizationService organizationService) : ICreateSponsorshipCommand + IOrganizationService organizationService, + IOrganizationUserRepository organizationUserRepository) : ICreateSponsorshipCommand { public async Task CreateSponsorshipAsync( Organization sponsoringOrganization, @@ -82,14 +83,26 @@ public class CreateSponsorshipCommand( if (existingOrgSponsorship != null) { - // Replace existing invalid offer with our new sponsorship offer sponsorship.Id = existingOrgSponsorship.Id; } } if (isAdminInitiated && sponsoringOrganization.Seats.HasValue) { - await organizationService.AutoAddSeatsAsync(sponsoringOrganization, 1); + var occupiedSeats = await organizationUserRepository.GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrganization.Id); + var availableSeats = sponsoringOrganization.Seats.Value - occupiedSeats; + + if (availableSeats <= 0) + { + var newSeatsRequired = 1; + var (canScale, failureReason) = await organizationService.CanScaleAsync(sponsoringOrganization, newSeatsRequired); + if (!canScale) + { + throw new BadRequestException(failureReason); + } + + await organizationService.AutoAddSeatsAsync(sponsoringOrganization, newSeatsRequired); + } } try diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs index f6b6721bd2..7dc6b7360d 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs @@ -168,6 +168,11 @@ public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase }); sutProvider.GetDependency().UserId.Returns(sponsoringOrgUser.UserId.Value); + // Setup for checking available seats + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrg.Id) + .Returns(0); + await sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName, false, null); @@ -293,6 +298,7 @@ public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase { sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; sponsoringOrg.UseAdminSponsoredFamilies = true; + sponsoringOrg.Seats = 10; sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId!.Value).Returns(user); @@ -311,6 +317,11 @@ public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase } ]); + // Setup for checking available seats - organization has plenty of seats + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrg.Id) + .Returns(5); + var actual = await sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName, true, notes); @@ -331,5 +342,121 @@ public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase await sutProvider.GetDependency().Received(1) .CreateAsync(Arg.Is(s => SponsorshipValidator(s, expectedSponsorship))); + + // Verify we didn't need to add seats + await sutProvider.GetDependency().DidNotReceive() + .AutoAddSeatsAsync(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData(OrganizationUserType.Admin)] + [BitAutoData(OrganizationUserType.Owner)] + public async Task CreateSponsorship_CreatesAdminInitiatedSponsorship_AutoscalesWhenNeeded( + OrganizationUserType organizationUserType, + Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, string sponsoredEmail, + string friendlyName, Guid sponsorshipId, Guid currentUserId, string notes, SutProvider sutProvider) + { + sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; + sponsoringOrg.UseAdminSponsoredFamilies = true; + sponsoringOrg.Seats = 10; + sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId!.Value).Returns(user); + sutProvider.GetDependency().WhenForAnyArgs(x => x.UpsertAsync(null!)).Do(callInfo => + { + var sponsorship = callInfo.Arg(); + sponsorship.Id = sponsorshipId; + }); + sutProvider.GetDependency().UserId.Returns(currentUserId); + sutProvider.GetDependency().Organizations.Returns([ + new() + { + Id = sponsoringOrg.Id, + Permissions = new Permissions { ManageUsers = true }, + Type = organizationUserType + } + ]); + + // Setup for checking available seats - organization has no available seats + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrg.Id) + .Returns(10); + + // Setup for checking if can scale + sutProvider.GetDependency() + .CanScaleAsync(sponsoringOrg, 1) + .Returns((true, "")); + + var actual = await sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, + PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName, true, notes); + + + var expectedSponsorship = new OrganizationSponsorship + { + Id = sponsorshipId, + SponsoringOrganizationId = sponsoringOrg.Id, + SponsoringOrganizationUserId = sponsoringOrgUser.Id, + FriendlyName = friendlyName, + OfferedToEmail = sponsoredEmail, + PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + IsAdminInitiated = true, + Notes = notes + }; + + Assert.True(SponsorshipValidator(expectedSponsorship, actual)); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Is(s => SponsorshipValidator(s, expectedSponsorship))); + + // Verify we needed to add seats + await sutProvider.GetDependency().Received(1) + .AutoAddSeatsAsync(sponsoringOrg, 1); + } + + [Theory] + [BitAutoData(OrganizationUserType.Admin)] + [BitAutoData(OrganizationUserType.Owner)] + public async Task CreateSponsorship_CreatesAdminInitiatedSponsorship_ThrowsWhenCannotAutoscale( + OrganizationUserType organizationUserType, + Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, string sponsoredEmail, + string friendlyName, Guid sponsorshipId, Guid currentUserId, string notes, SutProvider sutProvider) + { + sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; + sponsoringOrg.UseAdminSponsoredFamilies = true; + sponsoringOrg.Seats = 10; + sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId!.Value).Returns(user); + sutProvider.GetDependency().WhenForAnyArgs(x => x.UpsertAsync(null!)).Do(callInfo => + { + var sponsorship = callInfo.Arg(); + sponsorship.Id = sponsorshipId; + }); + sutProvider.GetDependency().UserId.Returns(currentUserId); + sutProvider.GetDependency().Organizations.Returns([ + new() + { + Id = sponsoringOrg.Id, + Permissions = new Permissions { ManageUsers = true }, + Type = organizationUserType + } + ]); + + // Setup for checking available seats - organization has no available seats + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(sponsoringOrg.Id) + .Returns(10); + + // Setup for checking if can scale - cannot scale + var failureReason = "Seat limit has been reached."; + sutProvider.GetDependency() + .CanScaleAsync(sponsoringOrg, 1) + .Returns((false, failureReason)); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, + PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName, true, notes)); + + Assert.Equal(failureReason, exception.Message); } }