diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index 2debd521a5..d28f291599 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -228,6 +228,26 @@ public class RemoveOrganizationFromProviderCommandTests Id = "subscription_id" }); + sutProvider.GetDependency() + .When(x => x.SetCreateOptionsAsync( + Arg.Is(options => + options.Customer == organization.GatewayCustomerId && + options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + options.DaysUntilDue == 30 && + options.Metadata["organizationId"] == organization.Id.ToString() && + options.OffSession == true && + options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && + options.Items.First().Quantity == organization.Seats) + , Arg.Any())) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index c1da732d60..d727f0c7d9 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -975,11 +975,12 @@ public class ProviderBillingServiceTests { provider.GatewaySubscriptionId = null; - sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer + var customer = new Customer { Id = "customer_id", Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } - }); + }; + sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(customer); var providerPlans = new List { @@ -1017,6 +1018,19 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + sutProvider.GetDependency() + .When(x => x.SetCreateOptionsAsync( + Arg.Is(options => + options.Customer == "customer_id") + , Arg.Is(p => p == customer))) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && diff --git a/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs b/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs index 4d93c0119a..dd8cf372e9 100644 --- a/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs +++ b/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs @@ -5,6 +5,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Migration.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Repositories; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; @@ -20,7 +21,8 @@ public class OrganizationMigrator( IMigrationTrackerCache migrationTrackerCache, IOrganizationRepository organizationRepository, IPricingClient pricingClient, - IStripeAdapter stripeAdapter) : IOrganizationMigrator + IStripeAdapter stripeAdapter, + IOrganizationAutomaticTaxStrategy automaticTaxStrategy) : IOrganizationMigrator { private const string _cancellationComment = "Cancelled as part of provider migration to Consolidated Billing"; @@ -231,10 +233,6 @@ public class OrganizationMigrator( var subscriptionCreateOptions = new SubscriptionCreateOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }, Customer = customer.Id, CollectionMethod = collectionMethod, DaysUntilDue = collectionMethod == StripeConstants.CollectionMethod.SendInvoice ? 30 : null, @@ -248,6 +246,8 @@ public class OrganizationMigrator( TrialPeriodDays = plan.TrialPeriodDays }; + await automaticTaxStrategy.SetCreateOptionsAsync(subscriptionCreateOptions, customer); + var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); organization.GatewaySubscriptionId = subscription.Id; diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/OrganizationAutomaticTaxStrategy.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/OrganizationAutomaticTaxStrategy.cs index 5852c7256c..be97337e8f 100644 --- a/src/Core/Billing/Services/Implementations/AutomaticTax/OrganizationAutomaticTaxStrategy.cs +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/OrganizationAutomaticTaxStrategy.cs @@ -74,42 +74,33 @@ public class OrganizationAutomaticTaxStrategy( private async Task IsEnabledAsync(Subscription subscription) { - if (subscription.AutomaticTax.Enabled || - !subscription.Customer.HasBillingLocation() || - await IsNonTaxableNonUsBusinessUseSubscriptionAsync(subscription)) + bool shouldBeEnabled; + if (subscription.Customer.HasBillingLocation() && subscription.Customer.Address.Country == "US") { - return null; + shouldBeEnabled = true; + } + else + { + var familyPriceIds = await _familyPriceIdsTask.Value; + shouldBeEnabled = subscription.Items.Select(item => item.Price.Id).Intersect(familyPriceIds).Any(); } - return !await IsNonTaxableNonUsBusinessUseSubscriptionAsync(subscription); - } + if (subscription.AutomaticTax.Enabled != shouldBeEnabled) + { + return shouldBeEnabled; + } - private async Task IsNonTaxableNonUsBusinessUseSubscriptionAsync(Subscription subscription) - { - var familyPriceIds = await _familyPriceIdsTask.Value; - - return subscription.Customer.Address.Country != "US" && - !subscription.Items.Select(item => item.Price.Id).Intersect(familyPriceIds).Any() && - !subscription.Customer.TaxIds.Any(); + return null; } private async Task IsEnabledAsync(SubscriptionCreateOptions options, Customer customer) { - if (!customer.HasBillingLocation() || - await IsNonTaxableNonUsBusinessUseSubscriptionAsync(options, customer)) + if (customer.HasBillingLocation() && customer.Address.Country == "US") { - return null; + return true; } - return !await IsNonTaxableNonUsBusinessUseSubscriptionAsync(options, customer); - } - - private async Task IsNonTaxableNonUsBusinessUseSubscriptionAsync(SubscriptionCreateOptions options, Customer customer) - { var familyPriceIds = await _familyPriceIdsTask.Value; - - return customer.Address.Country != "US" && - !options.Items.Select(item => item.Price).Intersect(familyPriceIds).Any() && - !customer.TaxIds.Any(); + return options.Items.Select(item => item.Price).Intersect(familyPriceIds).Any(); } } diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index 8b773f1cef..df5322fd73 100644 --- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -30,7 +30,8 @@ public class OrganizationBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - ITaxService taxService) : IOrganizationBillingService + ITaxService taxService, + IOrganizationAutomaticTaxStrategy organizationAutomaticTaxStrategy) : IOrganizationBillingService { public async Task Finalize(OrganizationSale sale) { @@ -380,10 +381,6 @@ public class OrganizationBillingService( var subscriptionCreateOptions = new SubscriptionCreateOptions { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = customerHasTaxInfo - }, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, @@ -395,6 +392,8 @@ public class OrganizationBillingService( TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays }; + await organizationAutomaticTaxStrategy.SetCreateOptionsAsync(subscriptionCreateOptions, customer); + return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); } diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index b2dca19e80..451f0df438 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -24,7 +24,9 @@ public class SubscriberService( ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ITaxService taxService) : ISubscriberService + ITaxService taxService, + IIndividualAutomaticTaxStrategy individualAutomaticTaxStrategy, + IOrganizationAutomaticTaxStrategy organizationAutomaticTaxStrategy) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -597,7 +599,7 @@ public class SubscriberService( Expand = ["subscriptions", "tax", "tax_ids"] }); - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + customer = await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Address = new AddressOptions { @@ -661,21 +663,17 @@ public class SubscriberService( } } - if (SubscriberIsEligibleForAutomaticTax(subscriber, customer)) + if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, - new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + var automaticTaxOptions = subscriber.IsUser() + ? await individualAutomaticTaxStrategy.GetUpdateOptionsAsync(subscription) + : await organizationAutomaticTaxStrategy.GetUpdateOptionsAsync(subscription); + if (automaticTaxOptions?.AutomaticTax?.Enabled != null) + { + await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions); + } } - - return; - - bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) - => !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && - (localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && - localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; } public async Task VerifyBankAccount( diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 5b7a2cc8bd..94dc85aa01 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -1561,6 +1561,31 @@ public class SubscriberServiceTests "Example Town", "NY"); + sutProvider.GetDependency() + .CustomerUpdateAsync( + Arg.Is(p => p == provider.GatewayCustomerId), + Arg.Is(options => + options.Address.Country == "US" && + options.Address.PostalCode == "12345" && + options.Address.Line1 == "123 Example St." && + options.Address.Line2 == null && + options.Address.City == "Example Town" && + options.Address.State == "NY")) + .Returns(new Customer + { + Id = provider.GatewayCustomerId, + Address = new Address + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Example St.", + Line2 = null, + City = "Example Town", + State = "NY" + }, + TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } + }); + await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is(