diff --git a/Directory.Build.props b/Directory.Build.props index 2ede6ad8d1..858abb2bc8 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2025.3.3 + 2025.4.0 Bit.$(MSBuildProjectName) enable diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index d2acdac079..2c34e57a92 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -1,4 +1,5 @@ -using Bit.Core.AdminConsole.Entities; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.Providers.Interfaces; @@ -7,10 +8,12 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; +using Microsoft.Extensions.DependencyInjection; using Stripe; namespace Bit.Commercial.Core.AdminConsole.Providers; @@ -28,6 +31,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv private readonly ISubscriberService _subscriberService; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; private readonly IPricingClient _pricingClient; + private readonly IAutomaticTaxStrategy _automaticTaxStrategy; public RemoveOrganizationFromProviderCommand( IEventService eventService, @@ -40,7 +44,8 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv IProviderBillingService providerBillingService, ISubscriberService subscriberService, IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, - IPricingClient pricingClient) + IPricingClient pricingClient, + [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) { _eventService = eventService; _mailService = mailService; @@ -53,6 +58,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv _subscriberService = subscriberService; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; _pricingClient = pricingClient; + _automaticTaxStrategy = automaticTaxStrategy; } public async Task RemoveOrganizationFromProvider( @@ -107,10 +113,11 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv organization.IsValidClient() && !string.IsNullOrEmpty(organization.GatewayCustomerId)) { - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + var customer = await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Description = string.Empty, - Email = organization.BillingEmail + Email = organization.BillingEmail, + Expand = ["tax", "tax_ids"] }); var plan = await _pricingClient.GetPlanOrThrow(organization.PlanType); @@ -120,7 +127,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv Customer = organization.GatewayCustomerId, CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, DaysUntilDue = 30, - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, Metadata = new Dictionary { { "organizationId", organization.Id.ToString() } @@ -130,6 +136,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }] }; + if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + _automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + } + else + { + subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + } + var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); organization.GatewaySubscriptionId = subscription.Id; diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index 74cfc1f916..757d6510f1 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -14,6 +14,7 @@ using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -22,6 +23,7 @@ using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using CsvHelper; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; @@ -29,10 +31,10 @@ namespace Bit.Commercial.Core.Billing; public class ProviderBillingService( IEventService eventService, + IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, IOrganizationRepository organizationRepository, - IPaymentService paymentService, IPricingClient pricingClient, IProviderInvoiceItemRepository providerInvoiceItemRepository, IProviderOrganizationRepository providerOrganizationRepository, @@ -40,7 +42,9 @@ public class ProviderBillingService( IProviderUserRepository providerUserRepository, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - ITaxService taxService) : IProviderBillingService + ITaxService taxService, + [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) + : IProviderBillingService { [RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)] public async Task AddExistingOrganization( @@ -143,36 +147,29 @@ public class ProviderBillingService( public async Task ChangePlan(ChangeProviderPlanCommand command) { - var plan = await providerPlanRepository.GetByIdAsync(command.ProviderPlanId); + var (provider, providerPlanId, newPlanType) = command; - if (plan == null) + var providerPlan = await providerPlanRepository.GetByIdAsync(providerPlanId); + + if (providerPlan == null) { throw new BadRequestException("Provider plan not found."); } - if (plan.PlanType == command.NewPlan) + if (providerPlan.PlanType == newPlanType) { return; } - var oldPlanConfiguration = await pricingClient.GetPlanOrThrow(plan.PlanType); - var newPlanConfiguration = await pricingClient.GetPlanOrThrow(command.NewPlan); + var subscription = await subscriberService.GetSubscriptionOrThrow(provider); - plan.PlanType = command.NewPlan; - await providerPlanRepository.ReplaceAsync(plan); + var oldPriceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); + var newPriceId = ProviderPriceAdapter.GetPriceId(provider, subscription, newPlanType); - Subscription subscription; - try - { - subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, plan.ProviderId); - } - catch (InvalidOperationException) - { - throw new ConflictException("Subscription not found."); - } + providerPlan.PlanType = newPlanType; + await providerPlanRepository.ReplaceAsync(providerPlan); - var oldSubscriptionItem = subscription.Items.SingleOrDefault(x => - x.Price.Id == oldPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId); + var oldSubscriptionItem = subscription.Items.SingleOrDefault(x => x.Price.Id == oldPriceId); var updateOptions = new SubscriptionUpdateOptions { @@ -180,7 +177,7 @@ public class ProviderBillingService( [ new SubscriptionItemOptions { - Price = newPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId, + Price = newPriceId, Quantity = oldSubscriptionItem!.Quantity }, new SubscriptionItemOptions @@ -191,12 +188,14 @@ public class ProviderBillingService( ] }; - await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId, updateOptions); + await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, updateOptions); // Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId) // 1. Retrieve PlanType and PlanName for ProviderPlan // 2. Assign PlanType & PlanName to Organization - var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(plan.ProviderId); + var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId); + + var newPlan = await pricingClient.GetPlanOrThrow(newPlanType); foreach (var providerOrganization in providerOrganizations) { @@ -205,8 +204,8 @@ public class ProviderBillingService( { throw new ConflictException($"Organization '{providerOrganization.Id}' not found."); } - organization.PlanType = command.NewPlan; - organization.Plan = newPlanConfiguration.Name; + organization.PlanType = newPlanType; + organization.Plan = newPlan.Name; await organizationRepository.ReplaceAsync(organization); } } @@ -400,7 +399,7 @@ public class ProviderBillingService( var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment; - var update = CurrySeatScalingUpdate( + var scaleQuantityTo = CurrySeatScalingUpdate( provider, providerPlan, newlyAssignedSeatTotal); @@ -423,9 +422,7 @@ public class ProviderBillingService( else if (currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum) { - await update( - seatMinimum, - newlyAssignedSeatTotal); + await scaleQuantityTo(newlyAssignedSeatTotal); } /* * Above the limit => Above the limit: @@ -434,9 +431,7 @@ public class ProviderBillingService( else if (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum) { - await update( - currentlyAssignedSeatTotal, - newlyAssignedSeatTotal); + await scaleQuantityTo(newlyAssignedSeatTotal); } /* * Above the limit => Below the limit: @@ -445,9 +440,7 @@ public class ProviderBillingService( else if (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal <= seatMinimum) { - await update( - currentlyAssignedSeatTotal, - seatMinimum); + await scaleQuantityTo(seatMinimum); } } @@ -557,7 +550,8 @@ public class ProviderBillingService( { ArgumentNullException.ThrowIfNull(provider); - var customer = await subscriberService.GetCustomerOrThrow(provider); + var customerGetOptions = new CustomerGetOptions { Expand = ["tax", "tax_ids"] }; + var customer = await subscriberService.GetCustomerOrThrow(provider, customerGetOptions); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); @@ -580,19 +574,17 @@ public class ProviderBillingService( throw new BillingException(); } + var priceId = ProviderPriceAdapter.GetActivePriceId(provider, providerPlan.PlanType); + subscriptionItemOptionsList.Add(new SubscriptionItemOptions { - Price = plan.PasswordManager.StripeProviderPortalSeatPlanId, + Price = priceId, Quantity = providerPlan.SeatMinimum }); } var subscriptionCreateOptions = new SubscriptionCreateOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }, CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, Customer = customer.Id, DaysUntilDue = 30, @@ -605,6 +597,15 @@ public class ProviderBillingService( ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations }; + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + } + else + { + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + } + try { var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); @@ -643,43 +644,37 @@ public class ProviderBillingService( public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command) { - if (command.Configuration.Any(x => x.SeatsMinimum < 0)) + var (provider, updatedPlanConfigurations) = command; + + if (updatedPlanConfigurations.Any(x => x.SeatsMinimum < 0)) { throw new BadRequestException("Provider seat minimums must be at least 0."); } - Subscription subscription; - try - { - subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, command.Id); - } - catch (InvalidOperationException) - { - throw new ConflictException("Subscription not found."); - } + var subscription = await subscriberService.GetSubscriptionOrThrow(provider); var subscriptionItemOptionsList = new List(); - var providerPlans = await providerPlanRepository.GetByProviderId(command.Id); + var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - foreach (var newPlanConfiguration in command.Configuration) + foreach (var updatedPlanConfiguration in updatedPlanConfigurations) { + var (updatedPlanType, updatedSeatMinimum) = updatedPlanConfiguration; + var providerPlan = - providerPlans.Single(providerPlan => providerPlan.PlanType == newPlanConfiguration.Plan); + providerPlans.Single(providerPlan => providerPlan.PlanType == updatedPlanType); - if (providerPlan.SeatMinimum != newPlanConfiguration.SeatsMinimum) + if (providerPlan.SeatMinimum != updatedSeatMinimum) { - var newPlan = await pricingClient.GetPlanOrThrow(newPlanConfiguration.Plan); - - var priceId = newPlan.PasswordManager.StripeProviderPortalSeatPlanId; + var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, updatedPlanType); var subscriptionItem = subscription.Items.First(item => item.Price.Id == priceId); if (providerPlan.PurchasedSeats == 0) { - if (providerPlan.AllocatedSeats > newPlanConfiguration.SeatsMinimum) + if (providerPlan.AllocatedSeats > updatedSeatMinimum) { - providerPlan.PurchasedSeats = providerPlan.AllocatedSeats - newPlanConfiguration.SeatsMinimum; + providerPlan.PurchasedSeats = providerPlan.AllocatedSeats - updatedSeatMinimum; subscriptionItemOptionsList.Add(new SubscriptionItemOptions { @@ -694,7 +689,7 @@ public class ProviderBillingService( { Id = subscriptionItem.Id, Price = priceId, - Quantity = newPlanConfiguration.SeatsMinimum + Quantity = updatedSeatMinimum }); } } @@ -702,9 +697,9 @@ public class ProviderBillingService( { var totalSeats = providerPlan.SeatMinimum + providerPlan.PurchasedSeats; - if (newPlanConfiguration.SeatsMinimum <= totalSeats) + if (updatedSeatMinimum <= totalSeats) { - providerPlan.PurchasedSeats = totalSeats - newPlanConfiguration.SeatsMinimum; + providerPlan.PurchasedSeats = totalSeats - updatedSeatMinimum; } else { @@ -713,12 +708,12 @@ public class ProviderBillingService( { Id = subscriptionItem.Id, Price = priceId, - Quantity = newPlanConfiguration.SeatsMinimum + Quantity = updatedSeatMinimum }); } } - providerPlan.SeatMinimum = newPlanConfiguration.SeatsMinimum; + providerPlan.SeatMinimum = updatedSeatMinimum; await providerPlanRepository.ReplaceAsync(providerPlan); } @@ -726,23 +721,33 @@ public class ProviderBillingService( if (subscriptionItemOptionsList.Count > 0) { - await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId, + await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { Items = subscriptionItemOptionsList }); } } - private Func CurrySeatScalingUpdate( + private Func CurrySeatScalingUpdate( Provider provider, ProviderPlan providerPlan, - int newlyAssignedSeats) => async (currentlySubscribedSeats, newlySubscribedSeats) => + int newlyAssignedSeats) => async newlySubscribedSeats => { - var plan = await pricingClient.GetPlanOrThrow(providerPlan.PlanType); + var subscription = await subscriberService.GetSubscriptionOrThrow(provider); - await paymentService.AdjustSeats( - provider, - plan, - currentlySubscribedSeats, - newlySubscribedSeats); + var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); + + var item = subscription.Items.First(item => item.Price.Id == priceId); + + await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions + { + Items = [ + new SubscriptionItemOptions + { + Id = item.Id, + Price = priceId, + Quantity = newlySubscribedSeats + } + ] + }); var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum ? newlySubscribedSeats - providerPlan.SeatMinimum diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderPriceAdapter.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderPriceAdapter.cs new file mode 100644 index 0000000000..4cc0711ec9 --- /dev/null +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderPriceAdapter.cs @@ -0,0 +1,133 @@ +// ReSharper disable SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault +#nullable enable +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing; +using Bit.Core.Billing.Enums; +using Stripe; + +namespace Bit.Commercial.Core.Billing; + +public static class ProviderPriceAdapter +{ + public static class MSP + { + public static class Active + { + public const string Enterprise = "provider-portal-enterprise-monthly-2025"; + public const string Teams = "provider-portal-teams-monthly-2025"; + } + + public static class Legacy + { + public const string Enterprise = "password-manager-provider-portal-enterprise-monthly-2024"; + public const string Teams = "password-manager-provider-portal-teams-monthly-2024"; + public static readonly List List = [Enterprise, Teams]; + } + } + + public static class BusinessUnit + { + public static class Active + { + public const string Annually = "business-unit-portal-enterprise-annually-2025"; + public const string Monthly = "business-unit-portal-enterprise-monthly-2025"; + } + + public static class Legacy + { + public const string Annually = "password-manager-provider-portal-enterprise-annually-2024"; + public const string Monthly = "password-manager-provider-portal-enterprise-monthly-2024"; + public static readonly List List = [Annually, Monthly]; + } + } + + /// + /// Uses the 's and to determine + /// whether the is on active or legacy pricing and then returns a Stripe price ID for the provided + /// based on that determination. + /// + /// The provider to get the Stripe price ID for. + /// The provider's subscription. + /// The plan type correlating to the desired Stripe price ID. + /// A Stripe ID. + /// Thrown when the provider's type is not or . + /// Thrown when the provided does not relate to a Stripe price ID. + public static string GetPriceId( + Provider provider, + Subscription subscription, + PlanType planType) + { + var priceIds = subscription.Items.Select(item => item.Price.Id); + + var invalidPlanType = + new BillingException(message: $"PlanType {planType} does not have an associated provider price in Stripe"); + + return provider.Type switch + { + ProviderType.Msp => MSP.Legacy.List.Intersect(priceIds).Any() + ? planType switch + { + PlanType.TeamsMonthly => MSP.Legacy.Teams, + PlanType.EnterpriseMonthly => MSP.Legacy.Enterprise, + _ => throw invalidPlanType + } + : planType switch + { + PlanType.TeamsMonthly => MSP.Active.Teams, + PlanType.EnterpriseMonthly => MSP.Active.Enterprise, + _ => throw invalidPlanType + }, + ProviderType.MultiOrganizationEnterprise => BusinessUnit.Legacy.List.Intersect(priceIds).Any() + ? planType switch + { + PlanType.EnterpriseAnnually => BusinessUnit.Legacy.Annually, + PlanType.EnterpriseMonthly => BusinessUnit.Legacy.Monthly, + _ => throw invalidPlanType + } + : planType switch + { + PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually, + PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly, + _ => throw invalidPlanType + }, + _ => throw new BillingException( + $"ProviderType {provider.Type} does not have any associated provider price IDs") + }; + } + + /// + /// Uses the 's to return the active Stripe price ID for the provided + /// . + /// + /// The provider to get the Stripe price ID for. + /// The plan type correlating to the desired Stripe price ID. + /// A Stripe ID. + /// Thrown when the provider's type is not or . + /// Thrown when the provided does not relate to a Stripe price ID. + public static string GetActivePriceId( + Provider provider, + PlanType planType) + { + var invalidPlanType = + new BillingException(message: $"PlanType {planType} does not have an associated provider price in Stripe"); + + return provider.Type switch + { + ProviderType.Msp => planType switch + { + PlanType.TeamsMonthly => MSP.Active.Teams, + PlanType.EnterpriseMonthly => MSP.Active.Enterprise, + _ => throw invalidPlanType + }, + ProviderType.MultiOrganizationEnterprise => planType switch + { + PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually, + PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly, + _ => throw invalidPlanType + }, + _ => throw new BillingException( + $"ProviderType {provider.Type} does not have any associated provider price IDs") + }; + } +} diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index 2debd521a5..48eda094e8 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -228,6 +228,26 @@ public class RemoveOrganizationFromProviderCommandTests Id = "subscription_id" }); + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == organization.GatewayCustomerId && + options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + options.DaysUntilDue == 30 && + options.Metadata["organizationId"] == organization.Id.ToString() && + options.OffSession == true && + options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && + options.Items.First().Quantity == organization.Seats) + , Arg.Any())) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index c1da732d60..ab1000d631 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -4,6 +4,7 @@ using Bit.Commercial.Core.Billing; using Bit.Commercial.Core.Billing.Models; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; @@ -115,6 +116,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.MultiOrganizationEnterprise; + var providerPlanRepository = sutProvider.GetDependency(); var existingPlan = new ProviderPlan { @@ -132,10 +135,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetPlanOrThrow(existingPlan.PlanType) .Returns(StaticStore.GetPlan(existingPlan.PlanType)); - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.ProviderSubscriptionGetAsync( - Arg.Is(provider.GatewaySubscriptionId), - Arg.Is(provider.Id)) + sutProvider.GetDependency().GetSubscriptionOrThrow(provider) .Returns(new Subscription { Id = provider.GatewaySubscriptionId, @@ -158,7 +158,7 @@ public class ProviderBillingServiceTests }); var command = - new ChangeProviderPlanCommand(providerPlanId, PlanType.EnterpriseMonthly, provider.GatewaySubscriptionId); + new ChangeProviderPlanCommand(provider, providerPlanId, PlanType.EnterpriseMonthly); sutProvider.GetDependency().GetPlanOrThrow(command.NewPlan) .Returns(StaticStore.GetPlan(command.NewPlan)); @@ -170,6 +170,8 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1) .ReplaceAsync(Arg.Is(p => p.PlanType == PlanType.EnterpriseMonthly)); + var stripeAdapter = sutProvider.GetDependency(); + await stripeAdapter.Received(1) .SubscriptionUpdateAsync( Arg.Is(provider.GatewaySubscriptionId), @@ -405,6 +407,23 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } }, + new SubscriptionItem + { + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } + } + ] + } + }; + + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); + // 50 seats currently assigned with a seat minimum of 100 var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); @@ -427,11 +446,9 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().AdjustSeats( - Arg.Any(), - Arg.Any(), - Arg.Any(), - Arg.Any()); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().SubscriptionUpdateAsync( + Arg.Any(), + Arg.Any()); await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( pPlan => pPlan.AllocatedSeats == 60)); @@ -474,6 +491,23 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } }, + new SubscriptionItem + { + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } + } + ] + } + }; + + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); + // 95 seats currently assigned with a seat minimum of 100 var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); @@ -496,11 +530,12 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 95 current + 10 seat scale = 105 seats, 5 above the minimum - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - providerPlan.SeatMinimum!.Value, - 105); + await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + provider.GatewaySubscriptionId, + Arg.Is( + options => + options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams && + options.Items.First().Quantity == 105)); // 105 total seats - 100 minimum = 5 purchased seats await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( @@ -544,6 +579,23 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } }, + new SubscriptionItem + { + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } + } + ] + } + }; + + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); + // 110 seats currently assigned with a seat minimum of 100 var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); @@ -566,11 +618,12 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 110 current + 10 seat scale up = 120 seats - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - 110, - 120); + await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + provider.GatewaySubscriptionId, + Arg.Is( + options => + options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams && + options.Items.First().Quantity == 120)); // 120 total seats - 100 seat minimum = 20 purchased seats await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( @@ -614,6 +667,23 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } }, + new SubscriptionItem + { + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } + } + ] + } + }; + + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); + // 110 seats currently assigned with a seat minimum of 100 var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); @@ -636,11 +706,12 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, -30); // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - 110, - providerPlan.SeatMinimum!.Value); + await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + provider.GatewaySubscriptionId, + Arg.Is( + options => + options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams && + options.Items.First().Quantity == providerPlan.SeatMinimum!.Value)); // Being below the seat minimum means no purchased seats. await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( @@ -924,11 +995,15 @@ public class ProviderBillingServiceTests { provider.GatewaySubscriptionId = null; - sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer - { - Id = "customer_id", - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } - }); + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) + .Returns(new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }); var providerPlans = new List { @@ -973,13 +1048,18 @@ public class ProviderBillingServiceTests SutProvider sutProvider, Provider provider) { + provider.Type = ProviderType.Msp; provider.GatewaySubscriptionId = null; - sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer + var customer = new Customer { Id = "customer_id", Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } - }); + }; + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); var providerPlans = new List { @@ -1012,11 +1092,21 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id) .Returns(providerPlans); - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); - var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == "customer_id") + , Arg.Is(p => p == customer))) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && @@ -1024,9 +1114,9 @@ public class ProviderBillingServiceTests sub.Customer == "customer_id" && sub.DaysUntilDue == 30 && sub.Items.Count == 2 && - sub.Items.ElementAt(0).Price == teamsPlan.PasswordManager.StripeProviderPortalSeatPlanId && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && sub.Items.ElementAt(0).Quantity == 100 && - sub.Items.ElementAt(1).Price == enterprisePlan.PasswordManager.StripeProviderPortalSeatPlanId && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && sub.Items.ElementAt(1).Quantity == 100 && sub.Metadata["providerId"] == provider.Id.ToString() && sub.OffSession == true && @@ -1048,8 +1138,7 @@ public class ProviderBillingServiceTests { // Arrange var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.TeamsMonthly, -10), (PlanType.EnterpriseMonthly, 50) @@ -1068,6 +1157,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1097,9 +1188,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync( - provider.GatewaySubscriptionId, - provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1116,8 +1205,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 30), (PlanType.TeamsMonthly, 20) @@ -1149,6 +1237,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1178,7 +1268,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1195,8 +1285,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 70), (PlanType.TeamsMonthly, 50) @@ -1228,6 +1317,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1257,7 +1348,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1274,8 +1365,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 60), (PlanType.TeamsMonthly, 60) @@ -1301,6 +1391,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1330,7 +1422,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1347,8 +1439,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 80), (PlanType.TeamsMonthly, 80) @@ -1380,6 +1471,8 @@ public class ProviderBillingServiceTests SutProvider sutProvider) { // Arrange + provider.Type = ProviderType.Msp; + var stripeAdapter = sutProvider.GetDependency(); var providerPlanRepository = sutProvider.GetDependency(); @@ -1409,7 +1502,7 @@ public class ProviderBillingServiceTests } }; - stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); + sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); var providerPlans = new List { @@ -1426,8 +1519,7 @@ public class ProviderBillingServiceTests providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); var command = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (PlanType.EnterpriseMonthly, 70), (PlanType.TeamsMonthly, 30) diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderPriceAdapterTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderPriceAdapterTests.cs new file mode 100644 index 0000000000..4fce78c05a --- /dev/null +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderPriceAdapterTests.cs @@ -0,0 +1,151 @@ +using Bit.Commercial.Core.Billing; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Enums; +using Stripe; +using Xunit; + +namespace Bit.Commercial.Core.Test.Billing; + +public class ProviderPriceAdapterTests +{ + [Theory] + [InlineData("password-manager-provider-portal-enterprise-monthly-2024", PlanType.EnterpriseMonthly)] + [InlineData("password-manager-provider-portal-teams-monthly-2024", PlanType.TeamsMonthly)] + public void GetPriceId_MSP_Legacy_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.Msp + }; + + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = priceId } } + ] + } + }; + + var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("provider-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)] + [InlineData("provider-portal-teams-monthly-2025", PlanType.TeamsMonthly)] + public void GetPriceId_MSP_Active_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.Msp + }; + + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = priceId } } + ] + } + }; + + var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("password-manager-provider-portal-enterprise-annually-2024", PlanType.EnterpriseAnnually)] + [InlineData("password-manager-provider-portal-enterprise-monthly-2024", PlanType.EnterpriseMonthly)] + public void GetPriceId_BusinessUnit_Legacy_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.MultiOrganizationEnterprise + }; + + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = priceId } } + ] + } + }; + + var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("business-unit-portal-enterprise-annually-2025", PlanType.EnterpriseAnnually)] + [InlineData("business-unit-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)] + public void GetPriceId_BusinessUnit_Active_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.MultiOrganizationEnterprise + }; + + var subscription = new Subscription + { + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = priceId } } + ] + } + }; + + var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("provider-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)] + [InlineData("provider-portal-teams-monthly-2025", PlanType.TeamsMonthly)] + public void GetActivePriceId_MSP_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.Msp + }; + + var result = ProviderPriceAdapter.GetActivePriceId(provider, planType); + + Assert.Equal(result, priceId); + } + + [Theory] + [InlineData("business-unit-portal-enterprise-annually-2025", PlanType.EnterpriseAnnually)] + [InlineData("business-unit-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)] + public void GetActivePriceId_BusinessUnit_Succeeds(string priceId, PlanType planType) + { + var provider = new Provider + { + Id = Guid.NewGuid(), + Type = ProviderType.MultiOrganizationEnterprise + }; + + var result = ProviderPriceAdapter.GetActivePriceId(provider, planType); + + Assert.Equal(result, priceId); + } +} diff --git a/src/Admin/AdminConsole/Controllers/ProvidersController.cs b/src/Admin/AdminConsole/Controllers/ProvidersController.cs index c38bb64419..0b1e4035df 100644 --- a/src/Admin/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Admin/AdminConsole/Controllers/ProvidersController.cs @@ -300,8 +300,7 @@ public class ProvidersController : Controller { case ProviderType.Msp: var updateMspSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (Plan: PlanType.TeamsMonthly, SeatsMinimum: model.TeamsMonthlySeatMinimum), (Plan: PlanType.EnterpriseMonthly, SeatsMinimum: model.EnterpriseMonthlySeatMinimum) @@ -314,15 +313,14 @@ public class ProvidersController : Controller // 1. Change the plan and take over any old values. var changeMoePlanCommand = new ChangeProviderPlanCommand( + provider, existingMoePlan.Id, - model.Plan!.Value, - provider.GatewaySubscriptionId); + model.Plan!.Value); await _providerBillingService.ChangePlan(changeMoePlanCommand); // 2. Update the seat minimums. var updateMoeSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( - provider.Id, - provider.GatewaySubscriptionId, + provider, [ (Plan: model.Plan!.Value, SeatsMinimum: model.EnterpriseMinimumSeats!.Value) ]); diff --git a/src/Api/Controllers/DevicesController.cs b/src/Api/Controllers/DevicesController.cs index 02eb2d36d5..4e21b5e9dc 100644 --- a/src/Api/Controllers/DevicesController.cs +++ b/src/Api/Controllers/DevicesController.cs @@ -1,6 +1,5 @@ using System.ComponentModel.DataAnnotations; using Bit.Api.Auth.Models.Request; -using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Models.Request; using Bit.Api.Models.Response; using Bit.Core.Auth.Models.Api.Request; @@ -125,7 +124,7 @@ public class DevicesController : Controller } [HttpPost("{identifier}/retrieve-keys")] - public async Task GetDeviceKeys(string identifier, [FromBody] SecretVerificationRequestModel model) + public async Task GetDeviceKeys(string identifier) { var user = await _userService.GetUserByPrincipalAsync(User); @@ -134,14 +133,7 @@ public class DevicesController : Controller throw new UnauthorizedAccessException(); } - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - var device = await _deviceRepository.GetByIdentifierAsync(identifier, user.Id); - if (device == null) { throw new NotFoundException(); diff --git a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs index 85e0981f22..0764e2ee28 100644 --- a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs +++ b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs @@ -8,6 +8,7 @@ using Bit.Api.Tools.Models.Request; using Bit.Api.Vault.Models.Request; using Bit.Core; using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Models.Api.Request; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; using Bit.Core.Exceptions; @@ -43,6 +44,7 @@ public class AccountsKeyManagementController : Controller _organizationUserValidator; private readonly IRotationValidator, IEnumerable> _webauthnKeyValidator; + private readonly IRotationValidator, IEnumerable> _deviceValidator; public AccountsKeyManagementController(IUserService userService, IFeatureService featureService, @@ -57,7 +59,8 @@ public class AccountsKeyManagementController : Controller emergencyAccessValidator, IRotationValidator, IReadOnlyList> organizationUserValidator, - IRotationValidator, IEnumerable> webAuthnKeyValidator) + IRotationValidator, IEnumerable> webAuthnKeyValidator, + IRotationValidator, IEnumerable> deviceValidator) { _userService = userService; _featureService = featureService; @@ -71,6 +74,7 @@ public class AccountsKeyManagementController : Controller _emergencyAccessValidator = emergencyAccessValidator; _organizationUserValidator = organizationUserValidator; _webauthnKeyValidator = webAuthnKeyValidator; + _deviceValidator = deviceValidator; } [HttpPost("regenerate-keys")] @@ -109,6 +113,7 @@ public class AccountsKeyManagementController : Controller EmergencyAccesses = await _emergencyAccessValidator.ValidateAsync(user, model.AccountUnlockData.EmergencyAccessUnlockData), OrganizationUsers = await _organizationUserValidator.ValidateAsync(user, model.AccountUnlockData.OrganizationAccountRecoveryUnlockData), WebAuthnKeys = await _webauthnKeyValidator.ValidateAsync(user, model.AccountUnlockData.PasskeyUnlockData), + DeviceKeys = await _deviceValidator.ValidateAsync(user, model.AccountUnlockData.DeviceKeyUnlockData), Ciphers = await _cipherValidator.ValidateAsync(user, model.AccountData.Ciphers), Folders = await _folderValidator.ValidateAsync(user, model.AccountData.Folders), diff --git a/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs b/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs index 5156e2a655..23c3eb95d0 100644 --- a/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs @@ -3,6 +3,7 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Auth.Models.Request.WebAuthn; +using Bit.Core.Auth.Models.Api.Request; namespace Bit.Api.KeyManagement.Models.Requests; @@ -13,4 +14,5 @@ public class UnlockDataRequestModel public required IEnumerable EmergencyAccessUnlockData { get; set; } public required IEnumerable OrganizationAccountRecoveryUnlockData { get; set; } public required IEnumerable PasskeyUnlockData { get; set; } + public required IEnumerable DeviceKeyUnlockData { get; set; } } diff --git a/src/Api/KeyManagement/Validators/DeviceRotationValidator.cs b/src/Api/KeyManagement/Validators/DeviceRotationValidator.cs new file mode 100644 index 0000000000..cbaf508766 --- /dev/null +++ b/src/Api/KeyManagement/Validators/DeviceRotationValidator.cs @@ -0,0 +1,53 @@ +using Bit.Core.Auth.Models.Api.Request; +using Bit.Core.Auth.Utilities; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; + +namespace Bit.Api.KeyManagement.Validators; + +/// +/// Device implementation for +/// +public class DeviceRotationValidator : IRotationValidator, IEnumerable> +{ + private readonly IDeviceRepository _deviceRepository; + + /// + /// Instantiates a new + /// + /// Retrieves all user s + public DeviceRotationValidator(IDeviceRepository deviceRepository) + { + _deviceRepository = deviceRepository; + } + + public async Task> ValidateAsync(User user, IEnumerable devices) + { + var result = new List(); + + var existingTrustedDevices = (await _deviceRepository.GetManyByUserIdAsync(user.Id)).Where(d => d.IsTrusted()).ToList(); + if (existingTrustedDevices.Count == 0) + { + return result; + } + + foreach (var existing in existingTrustedDevices) + { + var device = devices.FirstOrDefault(c => c.DeviceId == existing.Id); + if (device == null) + { + throw new BadRequestException("All existing trusted devices must be included in the rotation."); + } + + if (device.EncryptedUserKey == null || device.EncryptedPublicKey == null) + { + throw new BadRequestException("Rotated encryption keys must be provided for all devices that are trusted."); + } + + result.Add(device.ToDevice(existing)); + } + + return result; + } +} diff --git a/src/Api/NotificationCenter/Models/Response/NotificationResponseModel.cs b/src/Api/NotificationCenter/Models/Response/NotificationResponseModel.cs index 1ebed87de2..ab882d5557 100644 --- a/src/Api/NotificationCenter/Models/Response/NotificationResponseModel.cs +++ b/src/Api/NotificationCenter/Models/Response/NotificationResponseModel.cs @@ -22,6 +22,7 @@ public class NotificationResponseModel : ResponseModel Title = notificationStatusDetails.Title; Body = notificationStatusDetails.Body; Date = notificationStatusDetails.RevisionDate; + TaskId = notificationStatusDetails.TaskId; ReadDate = notificationStatusDetails.ReadDate; DeletedDate = notificationStatusDetails.DeletedDate; } @@ -40,6 +41,8 @@ public class NotificationResponseModel : ResponseModel public DateTime Date { get; set; } + public Guid? TaskId { get; set; } + public DateTime? ReadDate { get; set; } public DateTime? DeletedDate { get; set; } diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index e4c8ed7db7..cc1e533ffd 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -33,7 +33,7 @@ using Bit.Core.Auth.Identity.TokenProviders; using Bit.Core.Services; using Bit.Core.Tools.ImportFeatures; using Bit.Core.Tools.ReportFeatures; - +using Bit.Core.Auth.Models.Api.Request; #if !OSS using Bit.Commercial.Core.SecretsManager; @@ -170,6 +170,9 @@ public class Startup services .AddScoped, IEnumerable>, WebAuthnLoginKeyRotationValidator>(); + services + .AddScoped, IEnumerable>, + DeviceRotationValidator>(); // Services services.AddBaseServices(globalSettings); diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index d37bf41428..f75cbf8a8b 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -1,8 +1,11 @@ -using Bit.Core.AdminConsole.Repositories; +using Bit.Core; +using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,6 +15,7 @@ using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; public class UpcomingInvoiceHandler( + IFeatureService featureService, ILogger logger, IMailService mailService, IOrganizationRepository organizationRepository, @@ -21,7 +25,8 @@ public class UpcomingInvoiceHandler( IStripeEventService stripeEventService, IStripeEventUtilityService stripeEventUtilityService, IUserRepository userRepository, - IValidateSponsorshipCommand validateSponsorshipCommand) + IValidateSponsorshipCommand validateSponsorshipCommand, + IAutomaticTaxFactory automaticTaxFactory) : IUpcomingInvoiceHandler { public async Task HandleAsync(Event parsedEvent) @@ -136,6 +141,21 @@ public class UpcomingInvoiceHandler( private async Task TryEnableAutomaticTaxAsync(Subscription subscription) { + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); + var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); + + if (updateOptions == null) + { + return; + } + + await stripeFacade.UpdateSubscription(subscription.Id, updateOptions); + return; + } + if (subscription.AutomaticTax.Enabled || !subscription.Customer.HasBillingLocation() || await IsNonTaxableNonUSBusinessUseSubscription(subscription)) diff --git a/src/Core/AdminConsole/Entities/Organization.cs b/src/Core/AdminConsole/Entities/Organization.cs index 54661e22a7..e91f1ede29 100644 --- a/src/Core/AdminConsole/Entities/Organization.cs +++ b/src/Core/AdminConsole/Entities/Organization.cs @@ -313,5 +313,6 @@ public class Organization : ITableObject, IStorableSubscriber, IRevisable, UseSecretsManager = license.UseSecretsManager; SmSeats = license.SmSeats; SmServiceAccounts = license.SmServiceAccounts; + UseRiskInsights = license.UseRiskInsights; } } diff --git a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs index c53ac8745c..ab2dfd7e0e 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs @@ -148,7 +148,8 @@ public class SelfHostedOrganizationDetails : Organization LimitCollectionDeletion = LimitCollectionDeletion, LimitItemDeletion = LimitItemDeletion, AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems, - Status = Status + Status = Status, + UseRiskInsights = UseRiskInsights, }; } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs index 3d4b0fba5c..f122463a98 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs @@ -87,7 +87,10 @@ public class RestoreOrganizationUserCommand( .twoFactorIsEnabled; } - await CheckUserForOtherFreeOrganizationOwnershipAsync(organizationUser); + if (organization.PlanType == PlanType.Free) + { + await CheckUserForOtherFreeOrganizationOwnershipAsync(organizationUser); + } await CheckPoliciesBeforeRestoreAsync(organizationUser, userTwoFactorIsEnabled); @@ -100,7 +103,7 @@ public class RestoreOrganizationUserCommand( private async Task CheckUserForOtherFreeOrganizationOwnershipAsync(OrganizationUser organizationUser) { - var relatedOrgUsersFromOtherOrgs = await organizationUserRepository.GetManyByUserAsync(organizationUser.UserId.Value); + var relatedOrgUsersFromOtherOrgs = await organizationUserRepository.GetManyByUserAsync(organizationUser.UserId!.Value); var otherOrgs = await organizationRepository.GetManyByUserIdAsync(organizationUser.UserId.Value); var orgOrgUserDict = relatedOrgUsersFromOtherOrgs @@ -110,13 +113,16 @@ public class RestoreOrganizationUserCommand( CheckForOtherFreeOrganizationOwnership(organizationUser, orgOrgUserDict); } - private async Task> GetRelatedOrganizationUsersAndOrganizations( - IEnumerable organizationUsers) + private async Task> GetRelatedOrganizationUsersAndOrganizationsAsync( + List organizationUsers) { - var allUserIds = organizationUsers.Select(x => x.UserId.Value); + var allUserIds = organizationUsers + .Where(x => x.UserId.HasValue) + .Select(x => x.UserId.Value); var otherOrganizationUsers = (await organizationUserRepository.GetManyByManyUsersAsync(allUserIds)) - .Where(x => organizationUsers.Any(y => y.Id == x.Id) == false); + .Where(x => organizationUsers.Any(y => y.Id == x.Id) == false) + .ToArray(); var otherOrgs = await organizationRepository.GetManyByIdsAsync(otherOrganizationUsers .Select(x => x.OrganizationId) @@ -130,7 +136,9 @@ public class RestoreOrganizationUserCommand( Dictionary otherOrgUsersAndOrgs) { var ownerOrAdminList = new[] { OrganizationUserType.Owner, OrganizationUserType.Admin }; - if (otherOrgUsersAndOrgs.Any(x => + + if (ownerOrAdminList.Any(x => organizationUser.Type == x) && + otherOrgUsersAndOrgs.Any(x => x.Key.UserId == organizationUser.UserId && ownerOrAdminList.Any(userType => userType == x.Key.Type) && x.Key.Status == OrganizationUserStatusType.Confirmed && @@ -170,7 +178,7 @@ public class RestoreOrganizationUserCommand( var organizationUsersTwoFactorEnabled = await twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync( filteredUsers.Where(ou => ou.UserId.HasValue).Select(ou => ou.UserId.Value)); - var orgUsersAndOrgs = await GetRelatedOrganizationUsersAndOrganizations(filteredUsers); + var orgUsersAndOrgs = await GetRelatedOrganizationUsersAndOrganizationsAsync(filteredUsers); var result = new List>(); @@ -201,7 +209,10 @@ public class RestoreOrganizationUserCommand( await CheckPoliciesBeforeRestoreAsync(organizationUser, twoFactorIsEnabled); - CheckForOtherFreeOrganizationOwnership(organizationUser, orgUsersAndOrgs); + if (organization.PlanType == PlanType.Free) + { + CheckForOtherFreeOrganizationOwnership(organizationUser, orgUsersAndOrgs); + } var status = OrganizationService.GetPriorActiveOrganizationUserStatusType(organizationUser); diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/ResetPasswordPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/ResetPasswordPolicyRequirement.cs index 4feef1b088..b7d0b14f15 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/ResetPasswordPolicyRequirement.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/ResetPasswordPolicyRequirement.cs @@ -34,6 +34,8 @@ public class ResetPasswordPolicyRequirementFactory : BasePolicyRequirementFactor protected override IEnumerable ExemptRoles => []; + protected override IEnumerable ExemptStatuses => [OrganizationUserStatusType.Revoked]; + public override ResetPasswordPolicyRequirement Create(IEnumerable policyDetails) { var result = policyDetails diff --git a/src/Core/Auth/Models/Api/Request/DeviceKeysUpdateRequestModel.cs b/src/Core/Auth/Models/Api/Request/DeviceKeysUpdateRequestModel.cs index 2b815afd16..111b03a3a3 100644 --- a/src/Core/Auth/Models/Api/Request/DeviceKeysUpdateRequestModel.cs +++ b/src/Core/Auth/Models/Api/Request/DeviceKeysUpdateRequestModel.cs @@ -1,4 +1,5 @@ using System.ComponentModel.DataAnnotations; +using Bit.Core.Entities; using Bit.Core.Utilities; namespace Bit.Core.Auth.Models.Api.Request; @@ -7,6 +8,13 @@ public class OtherDeviceKeysUpdateRequestModel : DeviceKeysUpdateRequestModel { [Required] public Guid DeviceId { get; set; } + + public Device ToDevice(Device existingDevice) + { + existingDevice.EncryptedPublicKey = EncryptedPublicKey; + existingDevice.EncryptedUserKey = EncryptedUserKey; + return existingDevice; + } } public class DeviceKeysUpdateRequestModel diff --git a/src/Core/Auth/Models/Api/Response/DeviceAuthRequestResponseModel.cs b/src/Core/Auth/Models/Api/Response/DeviceAuthRequestResponseModel.cs index 3cfea51ee3..59630a6d2c 100644 --- a/src/Core/Auth/Models/Api/Response/DeviceAuthRequestResponseModel.cs +++ b/src/Core/Auth/Models/Api/Response/DeviceAuthRequestResponseModel.cs @@ -1,5 +1,4 @@ using Bit.Core.Auth.Models.Data; -using Bit.Core.Auth.Utilities; using Bit.Core.Enums; using Bit.Core.Models.Api; @@ -19,7 +18,7 @@ public class DeviceAuthRequestResponseModel : ResponseModel Type = deviceAuthDetails.Type, Identifier = deviceAuthDetails.Identifier, CreationDate = deviceAuthDetails.CreationDate, - IsTrusted = deviceAuthDetails.IsTrusted() + IsTrusted = deviceAuthDetails.IsTrusted, }; if (deviceAuthDetails.AuthRequestId != null && deviceAuthDetails.AuthRequestCreatedAt != null) diff --git a/src/Core/Auth/Services/Implementations/AuthRequestService.cs b/src/Core/Auth/Services/Implementations/AuthRequestService.cs index 42d51a88f5..0fd1846d00 100644 --- a/src/Core/Auth/Services/Implementations/AuthRequestService.cs +++ b/src/Core/Auth/Services/Implementations/AuthRequestService.cs @@ -289,6 +289,12 @@ public class AuthRequestService : IAuthRequestService { var adminEmails = await GetAdminAndAccountRecoveryEmailsAsync(organizationUser.OrganizationId); + if (adminEmails.Count == 0) + { + _logger.LogWarning("There are no admin emails to send to."); + return; + } + await _mailService.SendDeviceApprovalRequestedNotificationEmailAsync( adminEmails, organizationUser.OrganizationId, diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index 080416e2bb..326023e34c 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -47,6 +47,8 @@ public static class StripeConstants public static class MetadataKeys { public const string OrganizationId = "organizationId"; + public const string ProviderId = "providerId"; + public const string UserId = "userId"; } public static class PaymentBehavior diff --git a/src/Core/Billing/Extensions/CustomerExtensions.cs b/src/Core/Billing/Extensions/CustomerExtensions.cs index 1ab595342e..8f15f61a7f 100644 --- a/src/Core/Billing/Extensions/CustomerExtensions.cs +++ b/src/Core/Billing/Extensions/CustomerExtensions.cs @@ -21,7 +21,7 @@ public static class CustomerExtensions /// /// public static bool HasTaxLocationVerified(this Customer customer) => - customer?.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; + customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation; public static decimal GetBillingBalance(this Customer customer) { diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 26815d7df0..17285e0676 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -4,6 +4,7 @@ using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; namespace Bit.Core.Billing.Extensions; @@ -18,6 +19,9 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddKeyedTransient(AutomaticTaxFactory.PersonalUse); + services.AddKeyedTransient(AutomaticTaxFactory.BusinessUse); + services.AddTransient(); services.AddLicenseServices(); services.AddPricingClient(); } diff --git a/src/Core/Billing/Extensions/SubscriptionCreateOptionsExtensions.cs b/src/Core/Billing/Extensions/SubscriptionCreateOptionsExtensions.cs deleted file mode 100644 index d76a0553a3..0000000000 --- a/src/Core/Billing/Extensions/SubscriptionCreateOptionsExtensions.cs +++ /dev/null @@ -1,26 +0,0 @@ -using Stripe; - -namespace Bit.Core.Billing.Extensions; - -public static class SubscriptionCreateOptionsExtensions -{ - /// - /// Attempts to enable automatic tax for given new subscription options. - /// - /// - /// The existing customer. - /// Returns true when successful, false when conditions are not met. - public static bool EnableAutomaticTax(this SubscriptionCreateOptions options, Customer customer) - { - // We might only need to check the automatic tax status. - if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) - { - return false; - } - - options.DefaultTaxRates = []; - options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; - - return true; - } -} diff --git a/src/Core/Billing/Licenses/LicenseConstants.cs b/src/Core/Billing/Licenses/LicenseConstants.cs index 564019affc..50510914a5 100644 --- a/src/Core/Billing/Licenses/LicenseConstants.cs +++ b/src/Core/Billing/Licenses/LicenseConstants.cs @@ -36,6 +36,7 @@ public static class OrganizationLicenseConstants public const string SmServiceAccounts = nameof(SmServiceAccounts); public const string LimitCollectionCreationDeletion = nameof(LimitCollectionCreationDeletion); public const string AllowAdminAccessToAllCollectionItems = nameof(AllowAdminAccessToAllCollectionItems); + public const string UseRiskInsights = nameof(UseRiskInsights); public const string Expires = nameof(Expires); public const string Refresh = nameof(Refresh); public const string ExpirationWithoutGracePeriod = nameof(ExpirationWithoutGracePeriod); diff --git a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs index e436102012..62e1889564 100644 --- a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs +++ b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs @@ -47,6 +47,7 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory prices) + { + Subscriber = subscriber; + Prices = prices; + } + + public AutomaticTaxFactoryParameters(IEnumerable prices) + { + Prices = prices; + } + + public ISubscriber? Subscriber { get; init; } + + public PlanType? PlanType { get; init; } + + public IEnumerable? Prices { get; init; } +} diff --git a/src/Core/Billing/Services/Contracts/ChangeProviderPlansCommand.cs b/src/Core/Billing/Services/Contracts/ChangeProviderPlansCommand.cs index 3e8fffdd11..385782c8ad 100644 --- a/src/Core/Billing/Services/Contracts/ChangeProviderPlansCommand.cs +++ b/src/Core/Billing/Services/Contracts/ChangeProviderPlansCommand.cs @@ -1,8 +1,9 @@ -using Bit.Core.Billing.Enums; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Enums; namespace Bit.Core.Billing.Services.Contracts; public record ChangeProviderPlanCommand( + Provider Provider, Guid ProviderPlanId, - PlanType NewPlan, - string GatewaySubscriptionId); + PlanType NewPlan); diff --git a/src/Core/Billing/Services/Contracts/UpdateProviderSeatMinimumsCommand.cs b/src/Core/Billing/Services/Contracts/UpdateProviderSeatMinimumsCommand.cs index 86a596ffb6..2d2535b60a 100644 --- a/src/Core/Billing/Services/Contracts/UpdateProviderSeatMinimumsCommand.cs +++ b/src/Core/Billing/Services/Contracts/UpdateProviderSeatMinimumsCommand.cs @@ -1,10 +1,10 @@ -using Bit.Core.Billing.Enums; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Enums; namespace Bit.Core.Billing.Services.Contracts; -/// The ID of the provider to update the seat minimums for. +/// The provider to update the seat minimums for. /// The new seat minimums for the provider. public record UpdateProviderSeatMinimumsCommand( - Guid Id, - string GatewaySubscriptionId, + Provider Provider, IReadOnlyCollection<(PlanType Plan, int SeatsMinimum)> Configuration); diff --git a/src/Core/Billing/Services/IAutomaticTaxFactory.cs b/src/Core/Billing/Services/IAutomaticTaxFactory.cs new file mode 100644 index 0000000000..c52a8f2671 --- /dev/null +++ b/src/Core/Billing/Services/IAutomaticTaxFactory.cs @@ -0,0 +1,11 @@ +using Bit.Core.Billing.Services.Contracts; + +namespace Bit.Core.Billing.Services; + +/// +/// Responsible for defining the correct automatic tax strategy for either personal use of business use. +/// +public interface IAutomaticTaxFactory +{ + Task CreateAsync(AutomaticTaxFactoryParameters parameters); +} diff --git a/src/Core/Billing/Services/IAutomaticTaxStrategy.cs b/src/Core/Billing/Services/IAutomaticTaxStrategy.cs new file mode 100644 index 0000000000..292f2d0939 --- /dev/null +++ b/src/Core/Billing/Services/IAutomaticTaxStrategy.cs @@ -0,0 +1,33 @@ +#nullable enable +using Stripe; + +namespace Bit.Core.Billing.Services; + +public interface IAutomaticTaxStrategy +{ + /// + /// + /// + /// + /// + /// Returns if changes are to be applied to the subscription, returns null + /// otherwise. + /// + SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription); + + /// + /// Modifies an existing object with the automatic tax flag set correctly. + /// + /// + /// + void SetCreateOptions(SubscriptionCreateOptions options, Customer customer); + + /// + /// Modifies an existing object with the automatic tax flag set correctly. + /// + /// + /// + void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription); + + void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options); +} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs new file mode 100644 index 0000000000..133cd2c7a7 --- /dev/null +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs @@ -0,0 +1,50 @@ +#nullable enable +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Entities; +using Bit.Core.Services; + +namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; + +public class AutomaticTaxFactory( + IFeatureService featureService, + IPricingClient pricingClient) : IAutomaticTaxFactory +{ + public const string BusinessUse = "business-use"; + public const string PersonalUse = "personal-use"; + + private readonly Lazy>> _personalUsePlansTask = new(async () => + { + var plans = await Task.WhenAll( + pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), + pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)); + + return plans.Select(plan => plan.PasswordManager.StripePlanId); + }); + + public async Task CreateAsync(AutomaticTaxFactoryParameters parameters) + { + if (parameters.Subscriber is User) + { + return new PersonalUseAutomaticTaxStrategy(featureService); + } + + if (parameters.PlanType.HasValue) + { + var plan = await pricingClient.GetPlanOrThrow(parameters.PlanType.Value); + return plan.CanBeUsedByBusiness + ? new BusinessUseAutomaticTaxStrategy(featureService) + : new PersonalUseAutomaticTaxStrategy(featureService); + } + + var personalUsePlans = await _personalUsePlansTask.Value; + + if (parameters.Prices != null && parameters.Prices.Any(x => personalUsePlans.Any(y => y == x))) + { + return new PersonalUseAutomaticTaxStrategy(featureService); + } + + return new BusinessUseAutomaticTaxStrategy(featureService); + } +} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs new file mode 100644 index 0000000000..40eb6e4540 --- /dev/null +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs @@ -0,0 +1,96 @@ +#nullable enable +using Bit.Core.Billing.Extensions; +using Bit.Core.Services; +using Stripe; + +namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; + +public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy +{ + public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + { + return null; + } + + var shouldBeEnabled = ShouldBeEnabled(subscription.Customer); + if (subscription.AutomaticTax.Enabled == shouldBeEnabled) + { + return null; + } + + var options = new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = shouldBeEnabled + }, + DefaultTaxRates = [] + }; + + return options; + } + + public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) + { + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = ShouldBeEnabled(customer) + }; + } + + public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + { + return; + } + + var shouldBeEnabled = ShouldBeEnabled(subscription.Customer); + + if (subscription.AutomaticTax.Enabled == shouldBeEnabled) + { + return; + } + + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = shouldBeEnabled + }; + options.DefaultTaxRates = []; + } + + public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options) + { + options.AutomaticTax ??= new InvoiceAutomaticTaxOptions(); + + if (options.CustomerDetails.Address.Country == "US") + { + options.AutomaticTax.Enabled = true; + return; + } + + options.AutomaticTax.Enabled = options.CustomerDetails.TaxIds != null && options.CustomerDetails.TaxIds.Any(); + } + + private bool ShouldBeEnabled(Customer customer) + { + if (!customer.HasTaxLocationVerified()) + { + return false; + } + + if (customer.Address.Country == "US") + { + return true; + } + + if (customer.TaxIds == null) + { + throw new ArgumentNullException(nameof(customer.TaxIds), "`customer.tax_ids` must be expanded."); + } + + return customer.TaxIds.Any(); + } +} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs new file mode 100644 index 0000000000..15ee1adf8f --- /dev/null +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs @@ -0,0 +1,64 @@ +#nullable enable +using Bit.Core.Billing.Extensions; +using Bit.Core.Services; +using Stripe; + +namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; + +public class PersonalUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy +{ + public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) + { + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = ShouldBeEnabled(customer) + }; + } + + public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + { + return; + } + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = ShouldBeEnabled(subscription.Customer) + }; + options.DefaultTaxRates = []; + } + + public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + { + return null; + } + + if (subscription.AutomaticTax.Enabled == ShouldBeEnabled(subscription.Customer)) + { + return null; + } + + var options = new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = ShouldBeEnabled(subscription.Customer), + }, + DefaultTaxRates = [] + }; + + return options; + } + + public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options) + { + options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; + } + + private static bool ShouldBeEnabled(Customer customer) + { + return customer.HasTaxLocationVerified(); + } +} diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index 8b773f1cef..a4d22cfa3e 100644 --- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -1,9 +1,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -23,6 +25,7 @@ namespace Bit.Core.Billing.Services.Implementations; public class OrganizationBillingService( IBraintreeGateway braintreeGateway, + IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, IOrganizationRepository organizationRepository, @@ -30,7 +33,8 @@ public class OrganizationBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - ITaxService taxService) : IOrganizationBillingService + ITaxService taxService, + IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService { public async Task Finalize(OrganizationSale sale) { @@ -143,7 +147,7 @@ public class OrganizationBillingService( Coupon = customerSetup.Coupon, Description = organization.DisplayBusinessName(), Email = organization.BillingEmail, - Expand = ["tax"], + Expand = ["tax", "tax_ids"], InvoiceSettings = new CustomerInvoiceSettingsOptions { CustomFields = [ @@ -369,21 +373,8 @@ public class OrganizationBillingService( } } - var customerHasTaxInfo = customer is - { - Address: - { - Country: not null and not "", - PostalCode: not null and not "" - } - }; - var subscriptionCreateOptions = new SubscriptionCreateOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = customerHasTaxInfo - }, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, @@ -395,6 +386,18 @@ public class OrganizationBillingService( TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays }; + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType); + var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); + automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + } + else + { + subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions(); + subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation(); + } + return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); } diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index c00a151aa1..6746a8cc98 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -9,6 +10,7 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Braintree; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; using Customer = Stripe.Customer; @@ -20,12 +22,14 @@ using static Utilities; public class PremiumUserBillingService( IBraintreeGateway braintreeGateway, + IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - IUserRepository userRepository) : IPremiumUserBillingService + IUserRepository userRepository, + [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService { public async Task Credit(User user, decimal amount) { @@ -318,10 +322,6 @@ public class PremiumUserBillingService( var subscriptionCreateOptions = new SubscriptionCreateOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported, - }, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, @@ -335,6 +335,18 @@ public class PremiumUserBillingService( OffSession = true }; + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + } + else + { + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported, + }; + } + var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); if (usingPayPal) diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index b2dca19e80..e4b0594433 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -1,6 +1,7 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -20,11 +21,13 @@ namespace Bit.Core.Billing.Services.Implementations; public class SubscriberService( IBraintreeGateway braintreeGateway, + IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ITaxService taxService) : ISubscriberService + ITaxService taxService, + IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -438,7 +441,8 @@ public class SubscriberService( ArgumentNullException.ThrowIfNull(subscriber); ArgumentNullException.ThrowIfNull(tokenizedPaymentSource); - var customer = await GetCustomerOrThrow(subscriber); + var customerGetOptions = new CustomerGetOptions { Expand = ["tax", "tax_ids"] }; + var customer = await GetCustomerOrThrow(subscriber, customerGetOptions); var (type, token) = tokenizedPaymentSource; @@ -597,7 +601,7 @@ public class SubscriberService( Expand = ["subscriptions", "tax", "tax_ids"] }); - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + customer = await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Address = new AddressOptions { @@ -607,7 +611,8 @@ public class SubscriberService( Line2 = taxInformation.Line2, City = taxInformation.City, State = taxInformation.State - } + }, + Expand = ["subscriptions", "tax", "tax_ids"] }); var taxId = customer.TaxIds?.FirstOrDefault(); @@ -661,21 +666,42 @@ public class SubscriberService( } } - if (SubscriberIsEligibleForAutomaticTax(subscriber, customer)) + if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, - new SubscriptionUpdateOptions + if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + var subscriptionGetOptions = new SubscriptionGetOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); + Expand = ["customer.tax", "customer.tax_ids"] + }; + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); + var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription); + if (automaticTaxOptions?.AutomaticTax?.Enabled != null) + { + await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions); + } + } } + else + { + if (SubscriberIsEligibleForAutomaticTax(subscriber, customer)) + { + await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } - return; + return; - bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) - => !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && - (localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && - localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; + bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) + => !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && + (localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && + localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; + } } public async Task VerifyBankAccount( diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index b772002dbb..310b917bf7 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -148,6 +148,8 @@ public static class FeatureFlagKeys public const string P15179_AddExistingOrgsFromProviderPortal = "pm-15179-add-existing-orgs-from-provider-portal"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method"; + public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements"; + public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; /* Key Management Team */ public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; @@ -169,6 +171,7 @@ public static class FeatureFlagKeys public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication"; public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync"; public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias"; + public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias"; /* Platform Team */ diff --git a/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs b/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs index 7cb1c273a3..f81baf6fab 100644 --- a/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs +++ b/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs @@ -20,6 +20,7 @@ public class RotateUserAccountKeysData public IEnumerable EmergencyAccesses { get; set; } public IReadOnlyList OrganizationUsers { get; set; } public IEnumerable WebAuthnKeys { get; set; } + public IEnumerable DeviceKeys { get; set; } // User vault data encrypted by the userkey public IEnumerable Ciphers { get; set; } diff --git a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs index f4dcf31d5c..6967c9bf85 100644 --- a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs +++ b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs @@ -20,6 +20,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand private readonly ISendRepository _sendRepository; private readonly IEmergencyAccessRepository _emergencyAccessRepository; private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IDeviceRepository _deviceRepository; private readonly IPushNotificationService _pushService; private readonly IdentityErrorDescriber _identityErrorDescriber; private readonly IWebAuthnCredentialRepository _credentialRepository; @@ -42,6 +43,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand public RotateUserAccountKeysCommand(IUserService userService, IUserRepository userRepository, ICipherRepository cipherRepository, IFolderRepository folderRepository, ISendRepository sendRepository, IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository, + IDeviceRepository deviceRepository, IPasswordHasher passwordHasher, IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository) { @@ -52,6 +54,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand _sendRepository = sendRepository; _emergencyAccessRepository = emergencyAccessRepository; _organizationUserRepository = organizationUserRepository; + _deviceRepository = deviceRepository; _pushService = pushService; _identityErrorDescriber = errors; _credentialRepository = credentialRepository; @@ -127,6 +130,11 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand saveEncryptedDataActions.Add(_credentialRepository.UpdateKeysForRotationAsync(user.Id, model.WebAuthnKeys)); } + if (model.DeviceKeys.Any()) + { + saveEncryptedDataActions.Add(_deviceRepository.UpdateKeysForRotationAsync(user.Id, model.DeviceKeys)); + } + await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions); await _pushService.PushLogOutAsync(user.Id); return IdentityResult.Success; diff --git a/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.html.hbs b/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.html.hbs index 930d39eeee..67537b81a7 100644 --- a/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.html.hbs +++ b/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.html.hbs @@ -6,11 +6,8 @@ -
- {{OrgName}} has identified {{TaskCount}} critical login{{#if TaskCountPlural}}s{{/if}} that require{{#unless - TaskCountPlural}}s{{/unless}} a - password change + + {{OrgName}} has identified {{TaskCount}} critical {{plurality TaskCount "login" "logins"}} that {{plurality TaskCount "requires" "require"}} a password change
diff --git a/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.text.hbs b/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.text.hbs index f9befac46c..009e2b923f 100644 --- a/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.text.hbs +++ b/src/Core/MailTemplates/Handlebars/Layouts/SecurityTasks.text.hbs @@ -1,7 +1,5 @@ {{#>FullTextLayout}} -{{OrgName}} has identified {{TaskCount}} critical login{{#if TaskCountPlural}}s{{/if}} that require{{#unless -TaskCountPlural}}s{{/unless}} a -password change +{{OrgName}} has identified {{TaskCount}} critical {{plurality TaskCount "login" "logins"}} that {{plurality TaskCount "requires" "require"}} a password change {{>@partial-block}} diff --git a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs b/src/Core/Models/Business/ProviderSubscriptionUpdate.cs deleted file mode 100644 index 1fd833ca1f..0000000000 --- a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs +++ /dev/null @@ -1,62 +0,0 @@ -using Bit.Core.Billing; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; -using Stripe; -using Plan = Bit.Core.Models.StaticStore.Plan; - -namespace Bit.Core.Models.Business; - -public class ProviderSubscriptionUpdate : SubscriptionUpdate -{ - private readonly string _planId; - private readonly int _previouslyPurchasedSeats; - private readonly int _newlyPurchasedSeats; - - protected override List PlanIds => [_planId]; - - public ProviderSubscriptionUpdate( - Plan plan, - int previouslyPurchasedSeats, - int newlyPurchasedSeats) - { - if (!plan.Type.SupportsConsolidatedBilling()) - { - throw new BillingException( - message: $"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing"); - } - - _planId = plan.PasswordManager.StripeProviderPortalSeatPlanId; - _previouslyPurchasedSeats = previouslyPurchasedSeats; - _newlyPurchasedSeats = newlyPurchasedSeats; - } - - public override List RevertItemsOptions(Subscription subscription) - { - var subscriptionItem = FindSubscriptionItem(subscription, _planId); - - return - [ - new SubscriptionItemOptions - { - Id = subscriptionItem.Id, - Price = _planId, - Quantity = _previouslyPurchasedSeats - } - ]; - } - - public override List UpgradeItemsOptions(Subscription subscription) - { - var subscriptionItem = FindSubscriptionItem(subscription, _planId); - - return - [ - new SubscriptionItemOptions - { - Id = subscriptionItem.Id, - Price = _planId, - Quantity = _newlyPurchasedSeats - } - ]; - } -} diff --git a/src/Core/Models/Mail/SecurityTaskNotificationViewModel.cs b/src/Core/Models/Mail/SecurityTaskNotificationViewModel.cs index 9b4ede6e01..d41ca41146 100644 --- a/src/Core/Models/Mail/SecurityTaskNotificationViewModel.cs +++ b/src/Core/Models/Mail/SecurityTaskNotificationViewModel.cs @@ -6,8 +6,6 @@ public class SecurityTaskNotificationViewModel : BaseMailModel public int TaskCount { get; set; } - public bool TaskCountPlural => TaskCount != 1; - public List AdminOwnerEmails { get; set; } public string ReviewPasswordsUrl => $"{WebVaultUrl}/browser-extension-prompt"; diff --git a/src/Core/NotificationCenter/Models/Data/NotificationStatusDetails.cs b/src/Core/NotificationCenter/Models/Data/NotificationStatusDetails.cs index d48985e725..5ad8decb94 100644 --- a/src/Core/NotificationCenter/Models/Data/NotificationStatusDetails.cs +++ b/src/Core/NotificationCenter/Models/Data/NotificationStatusDetails.cs @@ -19,6 +19,7 @@ public class NotificationStatusDetails public string? Body { get; set; } public DateTime CreationDate { get; set; } public DateTime RevisionDate { get; set; } + public Guid? TaskId { get; set; } // Notification Status fields public DateTime? ReadDate { get; set; } public DateTime? DeletedDate { get; set; } diff --git a/src/Core/Repositories/IDeviceRepository.cs b/src/Core/Repositories/IDeviceRepository.cs index c9809c1de6..fc2f1556b7 100644 --- a/src/Core/Repositories/IDeviceRepository.cs +++ b/src/Core/Repositories/IDeviceRepository.cs @@ -1,5 +1,6 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; +using Bit.Core.KeyManagement.UserKey; #nullable enable @@ -16,4 +17,5 @@ public interface IDeviceRepository : IRepository // other requests. Task> GetManyByUserIdWithDeviceAuth(Guid userId); Task ClearPushTokenAsync(Guid id); + UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable devices); } diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index e3495c0e65..bd7efdbad4 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -1,5 +1,4 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Api.Requests.Accounts; using Bit.Core.Billing.Models.Api.Requests.Organizations; @@ -25,11 +24,6 @@ public interface IPaymentService int? newlyPurchasedAdditionalSecretsManagerServiceAccounts, int newlyPurchasedAdditionalStorage); Task AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats); - Task AdjustSeats( - Provider provider, - Plan plan, - int currentlySubscribedSeats, - int newlySubscribedSeats); Task AdjustSmSeatsAsync(Organization organization, Plan plan, int additionalSeats); Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId); diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index 430636f44d..a551342324 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -794,6 +794,29 @@ public class HandlebarsMailService : IMailService writer.WriteSafeString($"{outputMessage}"); }); + + // Returns the singular or plural form of a word based on the provided numeric value. + Handlebars.RegisterHelper("plurality", (writer, context, parameters) => + { + if (parameters.Length != 3) + { + writer.WriteSafeString(string.Empty); + return; + } + + var numeric = parameters[0]; + var singularText = parameters[1].ToString(); + var pluralText = parameters[2].ToString(); + + if (numeric is int number) + { + writer.WriteSafeString(number == 1 ? singularText : pluralText); + } + else + { + writer.WriteSafeString(string.Empty); + } + }); } public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index ca377407f4..d8889bca26 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1,5 +1,4 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; @@ -9,6 +8,8 @@ using Bit.Core.Billing.Models.Api.Responses; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -16,6 +17,7 @@ using Bit.Core.Models.BitStripe; using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Settings; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; using PaymentMethod = Stripe.PaymentMethod; @@ -36,6 +38,8 @@ public class StripePaymentService : IPaymentService private readonly ITaxService _taxService; private readonly ISubscriberService _subscriberService; private readonly IPricingClient _pricingClient; + private readonly IAutomaticTaxFactory _automaticTaxFactory; + private readonly IAutomaticTaxStrategy _personalUseTaxStrategy; public StripePaymentService( ITransactionRepository transactionRepository, @@ -46,7 +50,9 @@ public class StripePaymentService : IPaymentService IFeatureService featureService, ITaxService taxService, ISubscriberService subscriberService, - IPricingClient pricingClient) + IPricingClient pricingClient, + IAutomaticTaxFactory automaticTaxFactory, + [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy) { _transactionRepository = transactionRepository; _logger = logger; @@ -57,6 +63,8 @@ public class StripePaymentService : IPaymentService _taxService = taxService; _subscriberService = subscriberService; _pricingClient = pricingClient; + _automaticTaxFactory = automaticTaxFactory; + _personalUseTaxStrategy = personalUseTaxStrategy; } private async Task ChangeOrganizationSponsorship( @@ -91,9 +99,7 @@ public class StripePaymentService : IPaymentService SubscriptionUpdate subscriptionUpdate, bool invoiceNow = false) { // remember, when in doubt, throw - var subGetOptions = new SubscriptionGetOptions(); - // subGetOptions.AddExpand("customer"); - subGetOptions.AddExpand("customer.tax"); + var subGetOptions = new SubscriptionGetOptions { Expand = ["customer.tax", "customer.tax_ids"] }; var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions); if (sub == null) { @@ -124,7 +130,19 @@ public class StripePaymentService : IPaymentService new SubscriptionPendingInvoiceItemIntervalOptions { Interval = "month" }; } - subUpdateOptions.EnableAutomaticTax(sub.Customer, sub); + if (subscriptionUpdate is CompleteSubscriptionUpdate) + { + if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + { + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, updatedItemOptions.Select(x => x.Plan ?? x.Price)); + var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); + automaticTaxStrategy.SetUpdateOptions(subUpdateOptions, sub); + } + else + { + subUpdateOptions.EnableAutomaticTax(sub.Customer, sub); + } + } if (!subscriptionUpdate.UpdateNeeded(sub)) { @@ -232,18 +250,6 @@ public class StripePaymentService : IPaymentService public Task AdjustSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats) => FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats)); - public Task AdjustSeats( - Provider provider, - StaticStore.Plan plan, - int currentlySubscribedSeats, - int newlySubscribedSeats) - => FinalizeSubscriptionChangeAsync( - provider, - new ProviderSubscriptionUpdate( - plan, - currentlySubscribedSeats, - newlySubscribedSeats)); - public Task AdjustSmSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats) => FinalizeSubscriptionChangeAsync( organization, @@ -811,21 +817,46 @@ public class StripePaymentService : IPaymentService }); } - if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) && - customer.Subscriptions.Any(sub => - sub.Id == subscriber.GatewaySubscriptionId && - !sub.AutomaticTax.Enabled) && - customer.HasTaxLocationVerified()) + if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) { - var subscriptionUpdateOptions = new SubscriptionUpdateOptions + if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, - DefaultTaxRates = [] - }; + var subscriptionGetOptions = new SubscriptionGetOptions + { + Expand = ["customer.tax", "customer.tax_ids"] + }; + var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); - _ = await _stripeAdapter.SubscriptionUpdateAsync( - subscriber.GatewaySubscriptionId, - subscriptionUpdateOptions); + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); + var subscriptionUpdateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); + + if (subscriptionUpdateOptions != null) + { + _ = await _stripeAdapter.SubscriptionUpdateAsync( + subscriber.GatewaySubscriptionId, + subscriptionUpdateOptions); + } + } + } + else + { + if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) && + customer.Subscriptions.Any(sub => + sub.Id == subscriber.GatewaySubscriptionId && + !sub.AutomaticTax.Enabled) && + customer.HasTaxLocationVerified()) + { + var subscriptionUpdateOptions = new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, + DefaultTaxRates = [] + }; + + _ = await _stripeAdapter.SubscriptionUpdateAsync( + subscriber.GatewaySubscriptionId, + subscriptionUpdateOptions); + } } } catch @@ -1214,6 +1245,8 @@ public class StripePaymentService : IPaymentService } } + _personalUseTaxStrategy.SetInvoiceCreatePreviewOptions(options); + try { var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); @@ -1256,10 +1289,6 @@ public class StripePaymentService : IPaymentService var options = new InvoiceCreatePreviewOptions { - AutomaticTax = new InvoiceAutomaticTaxOptions - { - Enabled = true, - }, Currency = "usd", SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { @@ -1347,9 +1376,11 @@ public class StripePaymentService : IPaymentService ]; } + Customer gatewayCustomer = null; + if (!string.IsNullOrWhiteSpace(gatewayCustomerId)) { - var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); if (gatewayCustomer.Discount != null) { @@ -1367,6 +1398,10 @@ public class StripePaymentService : IPaymentService } } + var automaticTaxFactoryParameters = new AutomaticTaxFactoryParameters(parameters.PasswordManager.Plan); + var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxFactoryParameters); + automaticTaxStrategy.SetInvoiceCreatePreviewOptions(options); + try { var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); diff --git a/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs b/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs index 4abf4a4649..723200ff1c 100644 --- a/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs @@ -1,8 +1,10 @@ using System.Data; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; +using Bit.Core.KeyManagement.UserKey; using Bit.Core.Repositories; using Bit.Core.Settings; +using Bit.Core.Utilities; using Dapper; using Microsoft.Data.SqlClient; @@ -109,4 +111,35 @@ public class DeviceRepository : Repository, IDeviceRepository commandType: CommandType.StoredProcedure); } } + + public UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable devices) + { + return async (SqlConnection connection, SqlTransaction transaction) => + { + const string sql = @" + UPDATE D + SET + D.[EncryptedPublicKey] = UD.[encryptedPublicKey], + D.[EncryptedUserKey] = UD.[encryptedUserKey] + FROM + [dbo].[Device] D + INNER JOIN + OPENJSON(@DeviceCredentials) + WITH ( + id UNIQUEIDENTIFIER, + encryptedPublicKey NVARCHAR(MAX), + encryptedUserKey NVARCHAR(MAX) + ) UD + ON UD.[id] = D.[Id] + WHERE + D.[UserId] = @UserId"; + var deviceCredentials = CoreHelpers.ClassToJsonData(devices); + + await connection.ExecuteAsync( + sql, + new { UserId = userId, DeviceCredentials = deviceCredentials }, + transaction: transaction, + commandType: CommandType.Text); + }; + } } diff --git a/src/Infrastructure.EntityFramework/NotificationCenter/Repositories/Queries/NotificationStatusDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/NotificationCenter/Repositories/Queries/NotificationStatusDetailsViewQuery.cs index 2f8bade1d3..41f8610101 100644 --- a/src/Infrastructure.EntityFramework/NotificationCenter/Repositories/Queries/NotificationStatusDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/NotificationCenter/Repositories/Queries/NotificationStatusDetailsViewQuery.cs @@ -52,6 +52,7 @@ public class NotificationStatusDetailsViewQuery(Guid userId, ClientType clientTy ClientType = x.n.ClientType, UserId = x.n.UserId, OrganizationId = x.n.OrganizationId, + TaskId = x.n.TaskId, Title = x.n.Title, Body = x.n.Body, CreationDate = x.n.CreationDate, diff --git a/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs b/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs index ad31d0fb8b..19f38c6098 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs @@ -1,5 +1,6 @@ using AutoMapper; using Bit.Core.Auth.Models.Data; +using Bit.Core.KeyManagement.UserKey; using Bit.Core.Repositories; using Bit.Core.Settings; using Bit.Infrastructure.EntityFramework.Auth.Repositories.Queries; @@ -91,4 +92,30 @@ public class DeviceRepository : Repository, return await query.GetQuery(dbContext, userId, expirationMinutes).ToListAsync(); } } + + public UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable devices) + { + return async (_, _) => + { + var deviceUpdates = devices.ToList(); + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + var userDevices = await GetDbSet(dbContext) + .Where(device => device.UserId == userId) + .ToListAsync(); + var userDevicesWithUpdatesPending = userDevices + .Where(existingDevice => deviceUpdates.Any(updatedDevice => updatedDevice.Id == existingDevice.Id)) + .ToList(); + + foreach (var deviceToUpdate in userDevicesWithUpdatesPending) + { + var deviceUpdate = deviceUpdates.First(deviceUpdate => deviceUpdate.Id == deviceToUpdate.Id); + deviceToUpdate.EncryptedPublicKey = deviceUpdate.EncryptedPublicKey; + deviceToUpdate.EncryptedUserKey = deviceUpdate.EncryptedUserKey; + } + + await dbContext.SaveChangesAsync(); + }; + } + } diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 8fa74f7b84..441842da3b 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -135,6 +135,11 @@ public static class HubHelpers } break; + case PushType.PendingSecurityTasks: + var pendingTasksData = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); + await hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) + .SendAsync(_receiveMessageMethod, pendingTasksData, cancellationToken); + break; default: break; } diff --git a/src/Sql/NotificationCenter/dbo/Views/NotificationStatusDetailsView.sql b/src/Sql/NotificationCenter/dbo/Views/NotificationStatusDetailsView.sql index 5264be2009..57298152c7 100644 --- a/src/Sql/NotificationCenter/dbo/Views/NotificationStatusDetailsView.sql +++ b/src/Sql/NotificationCenter/dbo/Views/NotificationStatusDetailsView.sql @@ -1,10 +1,20 @@ CREATE VIEW [dbo].[NotificationStatusDetailsView] AS SELECT - N.*, - NS.UserId AS NotificationStatusUserId, - NS.ReadDate, - NS.DeletedDate + N.[Id], + N.[Priority], + N.[Global], + N.[ClientType], + N.[UserId], + N.[OrganizationId], + N.[Title], + N.[Body], + N.[CreationDate], + N.[RevisionDate], + N.[TaskId], + NS.[UserId] AS [NotificationStatusUserId], + NS.[ReadDate], + NS.[DeletedDate] FROM [dbo].[Notification] AS N LEFT JOIN diff --git a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs index 7c05e1d680..1b065adbd6 100644 --- a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs +++ b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs @@ -29,6 +29,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture _passwordHasher; private string _ownerEmail = null!; @@ -40,6 +41,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture(); + _deviceRepository = _factory.GetService(); _emergencyAccessRepository = _factory.GetService(); _organizationUserRepository = _factory.GetService(); _passwordHasher = _factory.GetService>(); @@ -238,10 +240,12 @@ public class AccountsKeyManagementControllerTests : IClassFixture sutProvider, User user, IEnumerable devices) + { + var userCiphers = devices.Select(c => new Device { Id = c.DeviceId, EncryptedPrivateKey = "EncryptedPrivateKey", EncryptedPublicKey = "EncryptedPublicKey", EncryptedUserKey = "EncryptedUserKey" }).ToList(); + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id) + .Returns(userCiphers); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.ValidateAsync(user, Enumerable.Empty())); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_SentDevicesTrustedButDatabaseUntrusted_Throws( + SutProvider sutProvider, User user, IEnumerable devices) + { + var userCiphers = devices.Select(c => new Device { Id = c.DeviceId, EncryptedPrivateKey = "Key", EncryptedPublicKey = "Key", EncryptedUserKey = "Key" }).ToList(); + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id) + .Returns(userCiphers); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.ValidateAsync(user, [ + new OtherDeviceKeysUpdateRequestModel { DeviceId = userCiphers.First().Id, EncryptedPublicKey = null, EncryptedUserKey = null } + ])); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_Validates( + SutProvider sutProvider, User user, IEnumerable devices) + { + var userCiphers = devices.Select(c => new Device { Id = c.DeviceId, EncryptedPrivateKey = "Key", EncryptedPublicKey = "Key", EncryptedUserKey = "Key" }).ToList().Slice(0, 1); + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id) + .Returns(userCiphers); + Assert.NotEmpty(await sutProvider.Sut.ValidateAsync(user, [ + new OtherDeviceKeysUpdateRequestModel { DeviceId = userCiphers.First().Id, EncryptedPublicKey = "Key", EncryptedUserKey = "Key" } + ])); + } +} diff --git a/test/Api.Test/NotificationCenter/Controllers/NotificationsControllerTests.cs b/test/Api.Test/NotificationCenter/Controllers/NotificationsControllerTests.cs index b8b21ef419..094ef2918e 100644 --- a/test/Api.Test/NotificationCenter/Controllers/NotificationsControllerTests.cs +++ b/test/Api.Test/NotificationCenter/Controllers/NotificationsControllerTests.cs @@ -67,6 +67,7 @@ public class NotificationsControllerTests Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date); Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate); Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate); + Assert.Equal(expectedNotificationStatusDetails.TaskId, notificationResponseModel.TaskId); }); Assert.Null(listResponse.ContinuationToken); @@ -116,6 +117,7 @@ public class NotificationsControllerTests Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date); Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate); Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate); + Assert.Equal(expectedNotificationStatusDetails.TaskId, notificationResponseModel.TaskId); }); Assert.Equal("2", listResponse.ContinuationToken); @@ -164,6 +166,7 @@ public class NotificationsControllerTests Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date); Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate); Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate); + Assert.Equal(expectedNotificationStatusDetails.TaskId, notificationResponseModel.TaskId); }); Assert.Null(listResponse.ContinuationToken); diff --git a/test/Api.Test/NotificationCenter/Models/Response/NotificationResponseModelTests.cs b/test/Api.Test/NotificationCenter/Models/Response/NotificationResponseModelTests.cs index f0dfc03fec..171b972575 100644 --- a/test/Api.Test/NotificationCenter/Models/Response/NotificationResponseModelTests.cs +++ b/test/Api.Test/NotificationCenter/Models/Response/NotificationResponseModelTests.cs @@ -26,6 +26,7 @@ public class NotificationResponseModelTests ClientType = ClientType.All, Title = "Test Title", Body = "Test Body", + TaskId = Guid.NewGuid(), RevisionDate = DateTime.UtcNow - TimeSpan.FromMinutes(3), ReadDate = DateTime.UtcNow - TimeSpan.FromMinutes(1), DeletedDate = DateTime.UtcNow, @@ -39,5 +40,6 @@ public class NotificationResponseModelTests Assert.Equal(model.Date, notificationStatusDetails.RevisionDate); Assert.Equal(model.ReadDate, notificationStatusDetails.ReadDate); Assert.Equal(model.DeletedDate, notificationStatusDetails.DeletedDate); + Assert.Equal(model.TaskId, notificationStatusDetails.TaskId); } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs index 726664849d..f91ca779a8 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs @@ -471,10 +471,11 @@ public class RestoreOrganizationUserCommandTests Organization organization, Organization otherOrganization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, - [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.Owner)] OrganizationUser organizationUser, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserOwnerFromDifferentOrg, SutProvider sutProvider) { + organization.PlanType = PlanType.Free; organizationUser.Email = null; // this is required to mock that the user as had already been confirmed before the revoke orgUserOwnerFromDifferentOrg.UserId = organizationUser.UserId; @@ -506,6 +507,107 @@ public class RestoreOrganizationUserCommandTests Assert.Equal("User is an owner/admin of another free organization. Please have them upgrade to a paid plan to restore their account.", exception.Message); } + [Theory, BitAutoData] + public async Task RestoreUser_WhenUserOwningAnotherFreeOrganizationAndIsOnlyAUserInCurrentOrg_ThenUserShouldBeRestored( + Organization organization, + Organization otherOrganization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserOwnerFromDifferentOrg, + SutProvider sutProvider) + { + organization.PlanType = PlanType.Free; + organizationUser.Email = null; // this is required to mock that the user as had already been confirmed before the revoke + + orgUserOwnerFromDifferentOrg.UserId = organizationUser.UserId; + otherOrganization.Id = orgUserOwnerFromDifferentOrg.OrganizationId; + otherOrganization.PlanType = PlanType.Free; + + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + var organizationUserRepository = sutProvider.GetDependency(); + organizationUserRepository + .GetManyByUserAsync(organizationUser.UserId.Value) + .Returns([orgUserOwnerFromDifferentOrg]); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(organizationUser.UserId.Value) + .Returns([otherOrganization]); + + sutProvider.GetDependency() + .GetPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.TwoFactorAuthentication, + Arg.Any()) + .Returns([ + new OrganizationUserPolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.TwoFactorAuthentication + } + ]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Is>(i => i.Contains(organizationUser.UserId.Value))) + .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) }); + + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + + await organizationUserRepository + .Received(1) + .RestoreAsync(organizationUser.Id, + Arg.Is(x => x != OrganizationUserStatusType.Revoked)); + } + + [Theory, BitAutoData] + public async Task RestoreUser_WhenUserOwningAnotherFreeOrganizationAndCurrentOrgIsNotFree_ThenUserShouldBeRestored( + Organization organization, + Organization otherOrganization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.Owner)] OrganizationUser organizationUser, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserOwnerFromDifferentOrg, + SutProvider sutProvider) + { + organization.PlanType = PlanType.EnterpriseAnnually2023; + + organizationUser.Email = null; // this is required to mock that the user as had already been confirmed before the revoke + + orgUserOwnerFromDifferentOrg.UserId = organizationUser.UserId; + otherOrganization.Id = orgUserOwnerFromDifferentOrg.OrganizationId; + otherOrganization.PlanType = PlanType.Free; + + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + var organizationUserRepository = sutProvider.GetDependency(); + organizationUserRepository + .GetManyByUserAsync(organizationUser.UserId.Value) + .Returns([orgUserOwnerFromDifferentOrg]); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(organizationUser.UserId.Value) + .Returns([otherOrganization]); + + sutProvider.GetDependency() + .GetPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.TwoFactorAuthentication, + Arg.Any()) + .Returns([ + new OrganizationUserPolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.TwoFactorAuthentication + } + ]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Is>(i => i.Contains(organizationUser.UserId.Value))) + .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) }); + + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + + await organizationUserRepository + .Received(1) + .RestoreAsync(organizationUser.Id, + Arg.Is(x => x != OrganizationUserStatusType.Revoked)); + } + [Theory, BitAutoData] public async Task RestoreUsers_Success(Organization organization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, @@ -612,7 +714,7 @@ public class RestoreOrganizationUserCommandTests [Theory, BitAutoData] public async Task RestoreUsers_UserOwnsAnotherFreeOrganization_BlocksOwnerUserFromBeingRestored(Organization organization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, - [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.Owner)] OrganizationUser orgUser1, [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser2, [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser3, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserFromOtherOrg, @@ -637,7 +739,7 @@ public class RestoreOrganizationUserCommandTests organizationUserRepository .GetManyAsync(Arg.Is>(ids => ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id) && ids.Contains(orgUser3.Id))) - .Returns(new[] { orgUser1, orgUser2, orgUser3 }); + .Returns([orgUser1, orgUser2, orgUser3]); userRepository.GetByIdAsync(orgUser2.UserId!.Value).Returns(new User { Email = "test@example.com" }); @@ -674,6 +776,110 @@ public class RestoreOrganizationUserCommandTests .RestoreAsync(orgUser1.Id, OrganizationUserStatusType.Confirmed); } + [Theory, BitAutoData] + public async Task RestoreUsers_UserOwnsAnotherFreeOrganizationButReactivatingOrgIsPaid_RestoresUser(Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.Owner)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserFromOtherOrg, + Organization otherOrganization, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = PlanType.EnterpriseAnnually2023; + + RestoreUser_Setup(organization, owner, orgUser1, sutProvider); + var organizationUserRepository = sutProvider.GetDependency(); + var policyService = sutProvider.GetDependency(); + var userService = Substitute.For(); + + orgUser1.OrganizationId = organization.Id; + + orgUserFromOtherOrg.UserId = orgUser1.UserId; + + otherOrganization.Id = orgUserFromOtherOrg.OrganizationId; + otherOrganization.PlanType = PlanType.Free; + + organizationUserRepository + .GetManyAsync(Arg.Is>(ids => ids.Contains(orgUser1.Id))) + .Returns([orgUser1]); + + organizationUserRepository + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([orgUserFromOtherOrg]); + + sutProvider.GetDependency() + .GetManyByIdsAsync(Arg.Is>(ids => ids.Contains(orgUserFromOtherOrg.OrganizationId))) + .Returns([otherOrganization]); + + + // Setup 2FA policy + policyService.GetPoliciesApplicableToUserAsync(Arg.Any(), PolicyType.TwoFactorAuthentication, Arg.Any()) + .Returns([new OrganizationUserPolicyDetails { OrganizationId = organization.Id, PolicyType = PolicyType.TwoFactorAuthentication }]); + + // User1 has 2FA, User2 doesn't + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Is>(ids => ids.Contains(orgUser1.UserId!.Value))) + .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> + { + (orgUser1.UserId!.Value, true) + }); + + // Act + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService); + + // Assert + Assert.Single(result); + Assert.Equal(string.Empty, result[0].Item2); + await organizationUserRepository + .Received(1) + .RestoreAsync(orgUser1.Id, Arg.Is(x => x != OrganizationUserStatusType.Revoked)); + } + + [Theory] + [BitAutoData] + public async Task RestoreUsers_UserOwnsAnotherOrganizationButIsOnlyUserOfCurrentOrganization_UserShouldBeRestored( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUserFromOtherOrg, + Organization otherOrganization, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = PlanType.Free; + + RestoreUser_Setup(organization, owner, orgUser1, sutProvider); + var organizationUserRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + orgUser1.OrganizationId = organization.Id; + + orgUserFromOtherOrg.UserId = orgUser1.UserId; + + otherOrganization.Id = orgUserFromOtherOrg.OrganizationId; + otherOrganization.PlanType = PlanType.Free; + + organizationUserRepository + .GetManyAsync(Arg.Is>(ids => ids.Contains(orgUser1.Id))) + .Returns([orgUser1]); + + organizationUserRepository + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([orgUserFromOtherOrg]); + + sutProvider.GetDependency().GetPoliciesApplicableToUserAsync(Arg.Any(), PolicyType.TwoFactorAuthentication, Arg.Any()) + .Returns([new OrganizationUserPolicyDetails { OrganizationId = organization.Id, PolicyType = PolicyType.TwoFactorAuthentication }]); + + // Act + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService); + + Assert.Single(result); + Assert.Equal(string.Empty, result[0].Item2); + await organizationUserRepository + .Received(1) + .RestoreAsync(orgUser1.Id, Arg.Is(x => x != OrganizationUserStatusType.Revoked)); + } + private static void RestoreUser_Setup( Organization organization, OrganizationUser? requestingOrganizationUser, diff --git a/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs b/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs index edd7a06fa7..eec6747c5f 100644 --- a/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs +++ b/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs @@ -17,6 +17,7 @@ using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; +using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; using GlobalSettings = Bit.Core.Settings.GlobalSettings; @@ -395,6 +396,87 @@ public class AuthRequestServiceTests user.Name); } + + [Theory, BitAutoData] + public async Task CreateAuthRequestAsync_AdminApproval_WithAdminNotifications_AndNoAdminEmails_ShouldNotSendNotificationEmails( + SutProvider sutProvider, + AuthRequestCreateRequestModel createModel, + User user, + OrganizationUser organizationUser1) + { + createModel.Type = AuthRequestType.AdminApproval; + user.Email = createModel.Email; + organizationUser1.UserId = user.Id; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DeviceApprovalRequestAdminNotifications) + .Returns(true); + + sutProvider.GetDependency() + .GetByEmailAsync(user.Email) + .Returns(user); + + sutProvider.GetDependency() + .DeviceType + .Returns(DeviceType.ChromeExtension); + + sutProvider.GetDependency() + .UserId + .Returns(user.Id); + + sutProvider.GetDependency() + .PasswordlessAuth.KnownDevicesOnly + .Returns(false); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns(new List + { + organizationUser1, + }); + + sutProvider.GetDependency() + .GetManyByMinimumRoleAsync(organizationUser1.OrganizationId, OrganizationUserType.Admin) + .Returns([]); + + sutProvider.GetDependency() + .GetManyDetailsByRoleAsync(organizationUser1.OrganizationId, OrganizationUserType.Custom) + .Returns([]); + + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(c => c.ArgAt(0)); + + var authRequest = await sutProvider.Sut.CreateAuthRequestAsync(createModel); + + Assert.Equal(organizationUser1.OrganizationId, authRequest.OrganizationId); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(Arg.Is(o => o.OrganizationId == organizationUser1.OrganizationId)); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .LogUserEventAsync(user.Id, EventType.User_RequestedDeviceApproval); + + await sutProvider.GetDependency() + .Received(0) + .SendDeviceApprovalRequestedNotificationEmailAsync( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + var expectedLogMessage = "There are no admin emails to send to."; + sutProvider.GetDependency>() + .Received(1) + .LogWarning(expectedLogMessage); + } + /// /// Story: When an is approved we want to update it in the database so it cannot have /// it's status changed again and we want to push a notification to let the user know of the approval. diff --git a/test/Core.Test/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategyTests.cs b/test/Core.Test/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategyTests.cs new file mode 100644 index 0000000000..dc40656275 --- /dev/null +++ b/test/Core.Test/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategyTests.cs @@ -0,0 +1,492 @@ +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Services.Implementations.AutomaticTax; + +[SutProviderCustomize] +public class BusinessUseAutomaticTaxStrategyTests +{ + [Theory] + [BitAutoData] + public void GetUpdateOptions_ReturnsNull_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled( + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(false); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.Null(actual); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_ReturnsNull_WhenSubscriptionDoesNotNeedUpdating( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.Null(actual); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.False(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForAmericanCustomers( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List + { + new() + { + Country = "ES", + Type = "eu_vat", + Value = "ESZ8880999Z" + } + } + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_ThrowsArgumentNullException_WhenTaxIdsIsNull( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = null + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + Assert.Throws(() => sutProvider.Sut.GetUpdateOptions(subscription)); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds( + SutProvider sutProvider) + { + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List() + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.False(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsNothing_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + Customer = new Customer + { + Address = new() + { + Country = "US" + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(false); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.Null(options.AutomaticTax); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsNothing_WhenSubscriptionDoesNotNeedUpdating( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.Null(options.AutomaticTax); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.False(options.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForAmericanCustomers( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.True(options.AutomaticTax!.Enabled); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List + { + new() + { + Country = "ES", + Type = "eu_vat", + Value = "ESZ8880999Z" + } + } + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.True(options.AutomaticTax!.Enabled); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_ThrowsArgumentNullException_WhenTaxIdsIsNull( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = null + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + Assert.Throws(() => sutProvider.Sut.SetUpdateOptions(options, subscription)); + } + + [Theory] + [BitAutoData] + public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds( + SutProvider sutProvider) + { + var options = new SubscriptionUpdateOptions(); + + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "ES", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List() + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + sutProvider.Sut.SetUpdateOptions(options, subscription); + + Assert.False(options.AutomaticTax!.Enabled); + } +} diff --git a/test/Core.Test/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategyTests.cs b/test/Core.Test/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategyTests.cs new file mode 100644 index 0000000000..2d50c9f75a --- /dev/null +++ b/test/Core.Test/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategyTests.cs @@ -0,0 +1,217 @@ +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Services.Implementations.AutomaticTax; + +[SutProviderCustomize] +public class PersonalUseAutomaticTaxStrategyTests +{ + [Theory] + [BitAutoData] + public void GetUpdateOptions_ReturnsNull_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled( + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(false); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.Null(actual); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_ReturnsNull_WhenSubscriptionDoesNotNeedUpdating( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Address = new Address + { + Country = "US", + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.Null(actual); + } + + [Theory] + [BitAutoData] + public void GetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid( + SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = true + }, + Customer = new Customer + { + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.False(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData("CA")] + [BitAutoData("ES")] + [BitAutoData("US")] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForAllCountries( + string country, SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = country + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData("CA")] + [BitAutoData("ES")] + [BitAutoData("US")] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds( + string country, SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = country, + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List + { + new() + { + Country = "ES", + Type = "eu_vat", + Value = "ESZ8880999Z" + } + } + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } + + [Theory] + [BitAutoData("CA")] + [BitAutoData("ES")] + [BitAutoData("US")] + public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds( + string country, SutProvider sutProvider) + { + var subscription = new Subscription + { + AutomaticTax = new SubscriptionAutomaticTax + { + Enabled = false + }, + Customer = new Customer + { + Address = new Address + { + Country = country + }, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + }, + TaxIds = new StripeList + { + Data = new List() + } + } + }; + + sutProvider.GetDependency() + .IsEnabled(Arg.Is(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates)) + .Returns(true); + + var actual = sutProvider.Sut.GetUpdateOptions(subscription); + + Assert.NotNull(actual); + Assert.True(actual.AutomaticTax.Enabled); + } +} diff --git a/test/Core.Test/Billing/Services/Implementations/AutomaticTaxFactoryTests.cs b/test/Core.Test/Billing/Services/Implementations/AutomaticTaxFactoryTests.cs new file mode 100644 index 0000000000..7d5c9c3a26 --- /dev/null +++ b/test/Core.Test/Billing/Services/Implementations/AutomaticTaxFactoryTests.cs @@ -0,0 +1,105 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; +using Bit.Core.Entities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Billing.Services.Implementations; + +[SutProviderCustomize] +public class AutomaticTaxFactoryTests +{ + [BitAutoData] + [Theory] + public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenSubscriberIsUser(SutProvider sut) + { + var parameters = new AutomaticTaxFactoryParameters(new User(), []); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + [BitAutoData] + [Theory] + public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenSubscriberIsOrganizationWithFamiliesAnnuallyPrice( + SutProvider sut) + { + var familiesPlan = new FamiliesPlan(); + var parameters = new AutomaticTaxFactoryParameters(new Organization(), [familiesPlan.PasswordManager.StripePlanId]); + + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) + .Returns(new FamiliesPlan()); + + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually2019)) + .Returns(new Families2019Plan()); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + [Theory] + [BitAutoData] + public async Task CreateAsync_ReturnsBusinessUseStrategy_WhenSubscriberIsOrganizationWithBusinessUsePrice( + EnterpriseAnnually plan, + SutProvider sut) + { + var parameters = new AutomaticTaxFactoryParameters(new Organization(), [plan.PasswordManager.StripePlanId]); + + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually)) + .Returns(new FamiliesPlan()); + + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == PlanType.FamiliesAnnually2019)) + .Returns(new Families2019Plan()); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + [Theory] + [BitAutoData] + public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenPlanIsMeantForPersonalUse(SutProvider sut) + { + var parameters = new AutomaticTaxFactoryParameters(PlanType.FamiliesAnnually); + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == parameters.PlanType.Value)) + .Returns(new FamiliesPlan()); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + [Theory] + [BitAutoData] + public async Task CreateAsync_ReturnsBusinessUseStrategy_WhenPlanIsMeantForBusinessUse(SutProvider sut) + { + var parameters = new AutomaticTaxFactoryParameters(PlanType.EnterpriseAnnually); + sut.GetDependency() + .GetPlanOrThrow(Arg.Is(p => p == parameters.PlanType.Value)) + .Returns(new EnterprisePlan(true)); + + var actual = await sut.Sut.CreateAsync(parameters); + + Assert.IsType(actual); + } + + public record EnterpriseAnnually : EnterprisePlan + { + public EnterpriseAnnually() : base(true) + { + } + } +} diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 5b7a2cc8bd..9e4be78787 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -3,10 +3,13 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Enums; using Bit.Core.Services; using Bit.Core.Settings; +using Bit.Core.Test.Billing.Stubs; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Braintree; @@ -1167,7 +1170,9 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId) + stripeAdapter.CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1213,7 +1218,10 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId) + stripeAdapter.CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")) + ) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1321,7 +1329,9 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1373,7 +1383,9 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1482,7 +1494,9 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().CustomerGetAsync( + provider.GatewayCustomerId, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer { Id = provider.GatewayCustomerId @@ -1561,6 +1575,37 @@ public class SubscriberServiceTests "Example Town", "NY"); + sutProvider.GetDependency() + .CustomerUpdateAsync( + Arg.Is(p => p == provider.GatewayCustomerId), + Arg.Is(options => + options.Address.Country == "US" && + options.Address.PostalCode == "12345" && + options.Address.Line1 == "123 Example St." && + options.Address.Line2 == null && + options.Address.City == "Example Town" && + options.Address.State == "NY")) + .Returns(new Customer + { + Id = provider.GatewayCustomerId, + Address = new Address + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Example St.", + Line2 = null, + City = "Example Town", + State = "NY" + }, + TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } + }); + + var subscription = new Subscription { Items = new StripeList() }; + sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + .Returns(subscription); + sutProvider.GetDependency().CreateAsync(Arg.Any()) + .Returns(new FakeAutomaticTaxStrategy(true)); + await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( diff --git a/test/Core.Test/Billing/Stubs/FakeAutomaticTaxStrategy.cs b/test/Core.Test/Billing/Stubs/FakeAutomaticTaxStrategy.cs new file mode 100644 index 0000000000..253aead5c7 --- /dev/null +++ b/test/Core.Test/Billing/Stubs/FakeAutomaticTaxStrategy.cs @@ -0,0 +1,35 @@ +using Bit.Core.Billing.Services; +using Stripe; + +namespace Bit.Core.Test.Billing.Stubs; + +/// +/// Whether the subscription options will have automatic tax enabled or not. +/// +public class FakeAutomaticTaxStrategy( + bool isAutomaticTaxEnabled) : IAutomaticTaxStrategy +{ + public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) + { + return new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled } + }; + } + + public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) + { + options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }; + } + + public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) + { + options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }; + } + + public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options) + { + options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }; + + } +} diff --git a/util/Migrator/DbScripts/2025-04-01_00_RecreateNotificationStatusView.sql b/util/Migrator/DbScripts/2025-04-01_00_RecreateNotificationStatusView.sql new file mode 100644 index 0000000000..727218f9ab --- /dev/null +++ b/util/Migrator/DbScripts/2025-04-01_00_RecreateNotificationStatusView.sql @@ -0,0 +1,25 @@ +-- Recreate the NotificationStatusView to include the Notification.TaskId column +CREATE OR ALTER VIEW [dbo].[NotificationStatusDetailsView] +AS +SELECT + N.[Id], + N.[Priority], + N.[Global], + N.[ClientType], + N.[UserId], + N.[OrganizationId], + N.[Title], + N.[Body], + N.[CreationDate], + N.[RevisionDate], + N.[TaskId], + NS.[UserId] AS [NotificationStatusUserId], + NS.[ReadDate], + NS.[DeletedDate] +FROM + [dbo].[Notification] AS N + LEFT JOIN + [dbo].[NotificationStatus] as NS +ON + N.[Id] = NS.[NotificationId] +GO