diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index 22a2e93642..35a00f4253 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -8,13 +8,10 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Services; -using Bit.Core.Billing.Tax.Services.Implementations; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; -using Microsoft.Extensions.DependencyInjection; using Stripe; namespace Bit.Commercial.Core.AdminConsole.Providers; @@ -24,7 +21,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv private readonly IEventService _eventService; private readonly IMailService _mailService; private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; private readonly IProviderOrganizationRepository _providerOrganizationRepository; private readonly IStripeAdapter _stripeAdapter; private readonly IFeatureService _featureService; @@ -32,26 +28,22 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv private readonly ISubscriberService _subscriberService; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; private readonly IPricingClient _pricingClient; - private readonly IAutomaticTaxStrategy _automaticTaxStrategy; public RemoveOrganizationFromProviderCommand( IEventService eventService, IMailService mailService, IOrganizationRepository organizationRepository, - IOrganizationService organizationService, IProviderOrganizationRepository providerOrganizationRepository, IStripeAdapter stripeAdapter, IFeatureService featureService, IProviderBillingService providerBillingService, ISubscriberService subscriberService, IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, - IPricingClient pricingClient, - [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) + IPricingClient pricingClient) { _eventService = eventService; _mailService = mailService; _organizationRepository = organizationRepository; - _organizationService = organizationService; _providerOrganizationRepository = providerOrganizationRepository; _stripeAdapter = stripeAdapter; _featureService = featureService; @@ -59,7 +51,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv _subscriberService = subscriberService; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; _pricingClient = pricingClient; - _automaticTaxStrategy = automaticTaxStrategy; } public async Task RemoveOrganizationFromProvider( @@ -77,7 +68,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync( providerOrganization.OrganizationId, - Array.Empty(), + [], includeProvider: false)) { throw new BadRequestException("Organization must have at least one confirmed owner."); @@ -102,7 +93,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv /// /// When a client organization is unlinked from a provider, we have to check if they're Stripe-enabled /// and, if they are, we remove their MSP discount and set their Subscription to `send_invoice`. This is because - /// the provider's payment method will be removed from their Stripe customer causing ensuing charges to fail. Lastly, + /// the provider's payment method will be removed from their Stripe customer, causing ensuing charges to fail. Lastly, /// we email the organization owners letting them know they need to add a new payment method. /// private async Task ResetOrganizationBillingAsync( @@ -142,15 +133,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }] }; - if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + var setNonUSBusinessUseToReverseCharge = _featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge) { - _automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - else + else if (customer.HasRecognizedTaxLocation()) { - subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { - Enabled = true + Enabled = customer.Address.Country == "US" || + customer.TaxIds.Any() }; } @@ -187,7 +181,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv await _mailService.SendProviderUpdatePaymentMethod( organization.Id, organization.Name, - provider.Name, + provider.Name!, organizationOwnerEmails); } } diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index fe6b8d4617..c8d6505183 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -18,7 +18,6 @@ using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Services; -using Bit.Core.Billing.Tax.Services.Implementations; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -27,7 +26,6 @@ using Bit.Core.Services; using Bit.Core.Settings; using Braintree; using CsvHelper; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; @@ -52,8 +50,7 @@ public class ProviderBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - ITaxService taxService, - [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) + ITaxService taxService) : IProviderBillingService { public async Task AddExistingOrganization( @@ -128,7 +125,7 @@ public class ProviderBillingService( /* * We have to scale the provider's seats before the ProviderOrganization - * row is inserted so the added organization's seats don't get double counted. + * row is inserted so the added organization's seats don't get double-counted. */ await ScaleSeats(provider, organization.PlanType, organization.Seats!.Value); @@ -236,7 +233,7 @@ public class ProviderBillingService( var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions { - Expand = ["tax_ids"] + Expand = ["tax", "tax_ids"] }); var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); @@ -284,6 +281,13 @@ public class ProviderBillingService( ] }; + var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge && providerCustomer.Address is not { Country: "US" }) + { + customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; + } + var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); organization.GatewayCustomerId = customer.Id; @@ -520,6 +524,13 @@ public class ProviderBillingService( } }; + var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge && taxInfo.BillingAddressCountry != "US") + { + options.TaxExempt = StripeConstants.TaxExempt.Reverse; + } + if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber)) { var taxIdType = taxService.GetStripeTaxCode( @@ -531,6 +542,7 @@ public class ProviderBillingService( logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber); + throw new BadRequestException("billingTaxIdTypeInferenceError"); } @@ -718,14 +730,21 @@ public class ProviderBillingService( TrialPeriodDays = trialPeriodDays }; - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) - { - automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); - } - else + var setNonUSBusinessUseToReverseCharge = + featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge) { subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } + else if (customer.HasRecognizedTaxLocation()) + { + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = customer.Address.Country == "US" || + customer.TaxIds.Any() + }; + } try { diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index b450bf5d7f..dd40d7d943 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -1,4 +1,5 @@ using Bit.Commercial.Core.AdminConsole.Providers; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -8,7 +9,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -224,31 +224,115 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Description == string.Empty && + options.Email == organization.BillingEmail && + options.Expand[0] == "tax" && + options.Expand[1] == "tax_ids")).Returns(new Customer + { + Id = "customer_id", + Address = new Address + { + Country = "US" + } + }); + stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription { Id = "subscription_id" }); - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == organization.GatewayCustomerId && - options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && - options.DaysUntilDue == 30 && - options.Metadata["organizationId"] == organization.Id.ToString() && - options.OffSession == true && - options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && - options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && - options.Items.First().Quantity == organization.Seats) - , Arg.Any())) - .Do(x => + await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); + + await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => + options.Customer == organization.GatewayCustomerId && + options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + options.DaysUntilDue == 30 && + options.AutomaticTax.Enabled == true && + options.Metadata["organizationId"] == organization.Id.ToString() && + options.OffSession == true && + options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && + options.Items.First().Quantity == organization.Seats)); + + await sutProvider.GetDependency().Received(1) + .ScaleSeats(provider, organization.PlanType, -organization.Seats ?? 0); + + await organizationRepository.Received(1).ReplaceAsync(Arg.Is( + org => + org.BillingEmail == "a@example.com" && + org.GatewaySubscriptionId == "subscription_id" && + org.Status == OrganizationStatusType.Created)); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(providerOrganization); + + await sutProvider.GetDependency().Received(1) + .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + + await sutProvider.GetDependency().Received(1) + .SendProviderUpdatePaymentMethod( + organization.Id, + organization.Name, + provider.Name, + Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); + } + + [Theory, BitAutoData] + public async Task RemoveOrganizationFromProvider_OrganizationStripeEnabled_ConsolidatedBilling_ReverseCharge_MakesCorrectInvocations( + Provider provider, + ProviderOrganization providerOrganization, + Organization organization, + SutProvider sutProvider) + { + provider.Status = ProviderStatusType.Billable; + + providerOrganization.ProviderId = provider.Id; + + organization.Status = OrganizationStatusType.Managed; + + organization.PlanType = PlanType.TeamsMonthly; + + var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + + sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan); + + sutProvider.GetDependency().HasConfirmedOwnersExceptAsync( + providerOrganization.OrganizationId, + [], + includeProvider: false) + .Returns(true); + + var organizationRepository = sutProvider.GetDependency(); + + organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns([ + "a@example.com", + "b@example.com" + ]); + + var stripeAdapter = sutProvider.GetDependency(); + + stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Description == string.Empty && + options.Email == organization.BillingEmail && + options.Expand[0] == "tax" && + options.Expand[1] == "tax_ids")).Returns(new Customer { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + Id = "customer_id", + Address = new Address { - Enabled = true - }; + Country = "US" + } }); + stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + { + Id = "subscription_id" + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index 2199bc4bfe..92094d026e 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -262,7 +262,7 @@ public class ProviderBillingServiceTests }; sutProvider.GetDependency().GetCustomerOrThrow(provider, Arg.Is( - options => options.Expand.FirstOrDefault() == "tax_ids")) + options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids"))) .Returns(providerCustomer); sutProvider.GetDependency().BaseServiceUri @@ -312,6 +312,91 @@ public class ProviderBillingServiceTests org => org.GatewayCustomerId == "customer_id")); } + [Theory, BitAutoData] + public async Task CreateCustomer_ForClientOrg_ReverseCharge_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = null; + organization.Name = "Name"; + organization.BusinessName = "BusinessName"; + + var providerCustomer = new Customer + { + Address = new Address + { + Country = "CA", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Unit 4", + City = "Fake Town", + State = "Fake State" + }, + TaxIds = new StripeList + { + Data = + [ + new TaxId { Type = "TYPE", Value = "VALUE" } + ] + } + }; + + sutProvider.GetDependency().GetCustomerOrThrow(provider, Arg.Is( + options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids"))) + .Returns(providerCustomer); + + sutProvider.GetDependency().BaseServiceUri + .Returns(new Bit.Core.Settings.GlobalSettings.BaseServiceUriSettings(new Bit.Core.Settings.GlobalSettings()) + { + CloudRegion = "US" + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + + sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + options => + options.Address.Country == providerCustomer.Address.Country && + options.Address.PostalCode == providerCustomer.Address.PostalCode && + options.Address.Line1 == providerCustomer.Address.Line1 && + options.Address.Line2 == providerCustomer.Address.Line2 && + options.Address.City == providerCustomer.Address.City && + options.Address.State == providerCustomer.Address.State && + options.Name == organization.DisplayName() && + options.Description == $"{provider.Name} Client Organization" && + options.Email == provider.BillingEmail && + options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && + options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && + options.Metadata["region"] == "US" && + options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && + options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value && + options.TaxExempt == StripeConstants.TaxExempt.Reverse)) + .Returns(new Customer { Id = "customer_id" }); + + await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); + + await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + options => + options.Address.Country == providerCustomer.Address.Country && + options.Address.PostalCode == providerCustomer.Address.PostalCode && + options.Address.Line1 == providerCustomer.Address.Line1 && + options.Address.Line2 == providerCustomer.Address.Line2 && + options.Address.City == providerCustomer.Address.City && + options.Address.State == providerCustomer.Address.State && + options.Name == organization.DisplayName() && + options.Description == $"{provider.Name} Client Organization" && + options.Email == provider.BillingEmail && + options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && + options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && + options.Metadata["region"] == "US" && + options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && + options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value)); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.GatewayCustomerId == "customer_id")); + } + #endregion #region GenerateClientInvoiceReport @@ -1182,6 +1267,62 @@ public class ProviderBillingServiceTests Assert.Equivalent(expected, actual); } + [Theory, BitAutoData] + public async Task SetupCustomer_WithCard_ReverseCharge_Success( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var expected = new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.PaymentMethod == tokenizedPaymentSource.Token && + o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber && + o.TaxExempt == StripeConstants.TaxExempt.Reverse)) + .Returns(expected); + + var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + + Assert.Equivalent(expected, actual); + } + [Theory, BitAutoData] public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid( SutProvider sutProvider, @@ -1307,7 +1448,7 @@ public class ProviderBillingServiceTests .Returns(new Customer { Id = "customer_id", - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + Address = new Address { Country = "US" } }); var providerPlans = new List @@ -1359,7 +1500,7 @@ public class ProviderBillingServiceTests var customer = new Customer { Id = "customer_id", - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + Address = new Address { Country = "US" } }; sutProvider.GetDependency() .GetCustomerOrThrow( @@ -1399,19 +1540,6 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == "customer_id") - , Arg.Is(p => p == customer))) - .Do(x => - { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }; - }); - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && @@ -1443,11 +1571,11 @@ public class ProviderBillingServiceTests var customer = new Customer { Id = "customer_id", + Address = new Address { Country = "US" }, InvoiceSettings = new CustomerInvoiceSettings { DefaultPaymentMethodId = "pm_123" - }, - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + } }; sutProvider.GetDependency() @@ -1488,19 +1616,6 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == "customer_id") - , Arg.Is(p => p == customer))) - .Do(x => - { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }; - }); - sutProvider.GetDependency() .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); @@ -1536,9 +1651,9 @@ public class ProviderBillingServiceTests var customer = new Customer { Id = "customer_id", + Address = new Address { Country = "US" }, InvoiceSettings = new CustomerInvoiceSettings(), - Metadata = new Dictionary(), - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + Metadata = new Dictionary() }; sutProvider.GetDependency() @@ -1579,19 +1694,6 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == "customer_id") - , Arg.Is(p => p == customer))) - .Do(x => - { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }; - }); - sutProvider.GetDependency() .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); @@ -1646,12 +1748,15 @@ public class ProviderBillingServiceTests var customer = new Customer { Id = "customer_id", + Address = new Address + { + Country = "US" + }, InvoiceSettings = new CustomerInvoiceSettings(), Metadata = new Dictionary { ["btCustomerId"] = "braintree_customer_id" - }, - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + } }; sutProvider.GetDependency() @@ -1692,22 +1797,92 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == "customer_id") - , Arg.Is(p => p == customer))) - .Do(x => + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && + sub.Customer == "customer_id" && + sub.DaysUntilDue == null && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + sub.TrialPeriodDays == 14)).Returns(expected); + + var actual = await sutProvider.Sut.SetupSubscription(provider); + + Assert.Equivalent(expected, actual); + } + + [Theory, BitAutoData] + public async Task SetupSubscription_ReverseCharge_Succeeds( + SutProvider sutProvider, + Provider provider) + { + provider.Type = ProviderType.Msp; + provider.GatewaySubscriptionId = null; + + var customer = new Customer + { + Id = "customer_id", + Address = new Address { Country = "CA" }, + InvoiceSettings = new CustomerInvoiceSettings { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }; - }); + DefaultPaymentMethodId = "pm_123" + } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }, + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + } + }; + + foreach (var plan in providerPlans) + { + sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) + .Returns(StaticStore.GetPlan(plan.PlanType)); + } + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; sutProvider.GetDependency() .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && diff --git a/src/Api/Billing/Controllers/OrganizationsController.cs b/src/Api/Billing/Controllers/OrganizationsController.cs index 510f6c2835..bd5ab8cef4 100644 --- a/src/Api/Billing/Controllers/OrganizationsController.cs +++ b/src/Api/Billing/Controllers/OrganizationsController.cs @@ -109,28 +109,6 @@ public class OrganizationsController( return license; } - [HttpPost("{id:guid}/payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostPayment(Guid id, [FromBody] PaymentRequestModel model) - { - if (!await currentContext.EditPaymentMethods(id)) - { - throw new NotFoundException(); - } - - await organizationService.ReplacePaymentMethodAsync(id, model.PaymentToken, - model.PaymentMethodType.Value, new TaxInfo - { - BillingAddressLine1 = model.Line1, - BillingAddressLine2 = model.Line2, - BillingAddressState = model.State, - BillingAddressCity = model.City, - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - TaxIdNumber = model.TaxId, - }); - } - [HttpPost("{id:guid}/upgrade")] [SelfHosted(NotSelfHostedOnly = true)] public async Task PostUpgrade(Guid id, [FromBody] OrganizationUpgradeRequestModel model) diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index 4c27098f38..e31d1dceb7 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -1,11 +1,11 @@ using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Services.Contracts; -using Bit.Core.Billing.Tax.Services; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -25,8 +25,7 @@ public class UpcomingInvoiceHandler( IStripeEventService stripeEventService, IStripeEventUtilityService stripeEventUtilityService, IUserRepository userRepository, - IValidateSponsorshipCommand validateSponsorshipCommand, - IAutomaticTaxFactory automaticTaxFactory) + IValidateSponsorshipCommand validateSponsorshipCommand) : IUpcomingInvoiceHandler { public async Task HandleAsync(Event parsedEvent) @@ -46,6 +45,8 @@ public class UpcomingInvoiceHandler( var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); + var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + if (organizationId.HasValue) { var organization = await organizationRepository.GetByIdAsync(organizationId.Value); @@ -55,7 +56,7 @@ public class UpcomingInvoiceHandler( return; } - await TryEnableAutomaticTaxAsync(subscription); + await AlignOrganizationTaxConcernsAsync(organization, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge); var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); @@ -100,7 +101,25 @@ public class UpcomingInvoiceHandler( return; } - await TryEnableAutomaticTaxAsync(subscription); + if (!subscription.AutomaticTax.Enabled && subscription.Customer.HasRecognizedTaxLocation()) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}", + user.Id, + parsedEvent.Id); + } + } if (user.Premium) { @@ -116,7 +135,7 @@ public class UpcomingInvoiceHandler( return; } - await TryEnableAutomaticTaxAsync(subscription); + await AlignProviderTaxConcernsAsync(provider, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge); await SendUpcomingInvoiceEmailsAsync(new List { provider.BillingEmail }, invoice); } @@ -139,50 +158,123 @@ public class UpcomingInvoiceHandler( } } - private async Task TryEnableAutomaticTaxAsync(Subscription subscription) + private async Task AlignOrganizationTaxConcernsAsync( + Organization organization, + Subscription subscription, + string eventId, + bool setNonUSBusinessUseToReverseCharge) { - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) - { - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id)); - var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); - var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); + var nonUSBusinessUse = + organization.PlanType.GetProductTier() != ProductTierType.Families && + subscription.Customer.Address.Country != "US"; - if (updateOptions == null) + bool setAutomaticTaxToEnabled; + + if (setNonUSBusinessUseToReverseCharge) + { + if (nonUSBusinessUse && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) { - return; + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}", + organization.Id, + eventId); + } } - await stripeFacade.UpdateSubscription(subscription.Id, updateOptions); - return; + setAutomaticTaxToEnabled = true; } - - if (subscription.AutomaticTax.Enabled || - !subscription.Customer.HasBillingLocation() || - await IsNonTaxableNonUSBusinessUseSubscription(subscription)) + else { - return; + setAutomaticTaxToEnabled = + subscription.Customer.HasRecognizedTaxLocation() && + (subscription.Customer.Address.Country == "US" || + (nonUSBusinessUse && subscription.Customer.TaxIds.Any())); } - await stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions + if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled) + { + try { - DefaultTaxRates = [], - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}", + organization.Id, + eventId); + } + } + } - return; + private async Task AlignProviderTaxConcernsAsync( + Provider provider, + Subscription subscription, + string eventId, + bool setNonUSBusinessUseToReverseCharge) + { + bool setAutomaticTaxToEnabled; - async Task IsNonTaxableNonUSBusinessUseSubscription(Subscription localSubscription) + if (setNonUSBusinessUseToReverseCharge) { - var familyPriceIds = (await Task.WhenAll( - pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), - pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually))) - .Select(plan => plan.PasswordManager.StripePlanId); + if (subscription.Customer.Address.Country != "US" && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) + { + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}", + provider.Id, + eventId); + } + } - return localSubscription.Customer.Address.Country != "US" && - localSubscription.Metadata.ContainsKey(StripeConstants.MetadataKeys.OrganizationId) && - !localSubscription.Items.Select(item => item.Price.Id).Intersect(familyPriceIds).Any() && - !localSubscription.Customer.TaxIds.Any(); + setAutomaticTaxToEnabled = true; + } + else + { + setAutomaticTaxToEnabled = + subscription.Customer.HasRecognizedTaxLocation() && + (subscription.Customer.Address.Country == "US" || + subscription.Customer.TaxIds.Any()); + } + + if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}", + provider.Id, + eventId); + } } } } diff --git a/src/Core/AdminConsole/Services/IOrganizationService.cs b/src/Core/AdminConsole/Services/IOrganizationService.cs index 1e53be734e..8baad23f65 100644 --- a/src/Core/AdminConsole/Services/IOrganizationService.cs +++ b/src/Core/AdminConsole/Services/IOrganizationService.cs @@ -11,8 +11,6 @@ namespace Bit.Core.Services; public interface IOrganizationService { - Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType, - TaxInfo taxInfo); Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null); Task ReinstateSubscriptionAsync(Guid organizationId); Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb); diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 1ced923b45..4e9d9bdb8a 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -144,27 +144,6 @@ public class OrganizationService : IOrganizationService _sendOrganizationInvitesCommand = sendOrganizationInvitesCommand; } - public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, - PaymentMethodType paymentMethodType, TaxInfo taxInfo) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - await _paymentService.SaveTaxInfoAsync(organization, taxInfo); - var updated = await _paymentService.UpdatePaymentMethodAsync( - organization, - paymentMethodType, - paymentToken, - taxInfo); - if (updated) - { - await ReplaceAndUpdateCacheAsync(organization); - } - } - public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null) { var organization = await GetOrgById(organizationId); diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index c3e3ec6c30..28f4dea4b2 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -2,10 +2,6 @@ public static class StripeConstants { - public static class Prices - { - public const string StoragePlanPersonal = "personal-storage-gb-annually"; - } public static class AutomaticTaxStatus { public const string Failed = "failed"; @@ -69,6 +65,11 @@ public static class StripeConstants public const string USBankAccount = "us_bank_account"; } + public static class Prices + { + public const string StoragePlanPersonal = "personal-storage-gb-annually"; + } + public static class ProrationBehavior { public const string AlwaysInvoice = "always_invoice"; @@ -88,6 +89,13 @@ public static class StripeConstants public const string Paused = "paused"; } + public static class TaxExempt + { + public const string Exempt = "exempt"; + public const string None = "none"; + public const string Reverse = "reverse"; + } + public static class ValidateTaxLocationTiming { public const string Deferred = "deferred"; diff --git a/src/Core/Billing/Extensions/CustomerExtensions.cs b/src/Core/Billing/Extensions/CustomerExtensions.cs index 3e0c1ea0fb..aa22331f7c 100644 --- a/src/Core/Billing/Extensions/CustomerExtensions.cs +++ b/src/Core/Billing/Extensions/CustomerExtensions.cs @@ -15,12 +15,7 @@ public static class CustomerExtensions } }; - /// - /// Determines if a Stripe customer supports automatic tax - /// - /// - /// - public static bool HasTaxLocationVerified(this Customer customer) => + public static bool HasRecognizedTaxLocation(this Customer customer) => customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation; public static decimal GetBillingBalance(this Customer customer) diff --git a/src/Core/Billing/Extensions/SubscriptionUpdateOptionsExtensions.cs b/src/Core/Billing/Extensions/SubscriptionUpdateOptionsExtensions.cs index d70af78fa8..22a715733b 100644 --- a/src/Core/Billing/Extensions/SubscriptionUpdateOptionsExtensions.cs +++ b/src/Core/Billing/Extensions/SubscriptionUpdateOptionsExtensions.cs @@ -22,7 +22,7 @@ public static class SubscriptionUpdateOptionsExtensions } // We might only need to check the automatic tax status. - if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) + if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country)) { return false; } diff --git a/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs b/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs index 88df5638c9..d00b5b46a4 100644 --- a/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs +++ b/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs @@ -22,7 +22,7 @@ public static class UpcomingInvoiceOptionsExtensions } // We might only need to check the automatic tax status. - if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) + if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country)) { return false; } diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index 20f6105c2a..95df34dfd4 100644 --- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -1,11 +1,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; @@ -35,16 +35,15 @@ public class OrganizationBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - ITaxService taxService, - IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService + ITaxService taxService) : IOrganizationBillingService { public async Task Finalize(OrganizationSale sale) { var (organization, customerSetup, subscriptionSetup) = sale; var customer = string.IsNullOrEmpty(organization.GatewayCustomerId) && customerSetup != null - ? await CreateCustomerAsync(organization, customerSetup) - : await subscriberService.GetCustomerOrThrow(organization, new CustomerGetOptions { Expand = ["tax", "tax_ids"] }); + ? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType) + : await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup); var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup); @@ -121,7 +120,8 @@ public class OrganizationBillingService( subscription.CurrentPeriodEnd); } - public async Task UpdatePaymentMethod( + public async Task + UpdatePaymentMethod( Organization organization, TokenizedPaymentSource tokenizedPaymentSource, TaxInformation taxInformation) @@ -151,8 +151,11 @@ public class OrganizationBillingService( private async Task CreateCustomerAsync( Organization organization, - CustomerSetup customerSetup) + CustomerSetup customerSetup, + PlanType? updatedPlanType = null) { + var planType = updatedPlanType ?? organization.PlanType; + var displayName = organization.DisplayName(); var customerCreateOptions = new CustomerCreateOptions @@ -212,13 +215,24 @@ public class OrganizationBillingService( City = customerSetup.TaxInformation.City, PostalCode = customerSetup.TaxInformation.PostalCode, State = customerSetup.TaxInformation.State, - Country = customerSetup.TaxInformation.Country, + Country = customerSetup.TaxInformation.Country }; + customerCreateOptions.Tax = new CustomerTaxOptions { ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately }; + var setNonUSBusinessUseToReverseCharge = + featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge && + planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families && + customerSetup.TaxInformation.Country != "US") + { + customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; + } + if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId)) { var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country, @@ -399,21 +413,68 @@ public class OrganizationBillingService( TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays }; - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + var setNonUSBusinessUseToReverseCharge = + featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge) { - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType); - var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); - automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - else + else if (customer.HasRecognizedTaxLocation()) { - subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions(); - subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation(); + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = + subscriptionSetup.PlanType.GetProductTier() == ProductTierType.Families || + customer.Address.Country == "US" || + customer.TaxIds.Any() + }; } return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); } + private async Task GetCustomerWhileEnsuringCorrectTaxExemptionAsync( + Organization organization, + SubscriptionSetup subscriptionSetup) + { + var customer = await subscriberService.GetCustomerOrThrow(organization, + new CustomerGetOptions { Expand = ["tax", "tax_ids"] }); + + var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (!setNonUSBusinessUseToReverseCharge || subscriptionSetup.PlanType.GetProductTier() is + not (ProductTierType.Teams or + ProductTierType.TeamsStarter or + ProductTierType.Enterprise)) + { + return customer; + } + + List expansions = ["tax", "tax_ids"]; + + customer = customer switch + { + { Address.Country: not "US", TaxExempt: not StripeConstants.TaxExempt.Reverse } => await + stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions + { + Expand = expansions, + TaxExempt = StripeConstants.TaxExempt.Reverse + }), + { Address.Country: "US", TaxExempt: StripeConstants.TaxExempt.Reverse } => await + stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions + { + Expand = expansions, + TaxExempt = StripeConstants.TaxExempt.None + }), + _ => customer + }; + + return customer; + } + private async Task IsEligibleForSelfHostAsync( Organization organization) { diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 1b845e93f1..7496157aaa 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -3,8 +3,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Billing.Tax.Services; -using Bit.Core.Billing.Tax.Services.Implementations; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -12,7 +10,6 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Braintree; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; using Customer = Stripe.Customer; @@ -24,20 +21,18 @@ using static Utilities; public class PremiumUserBillingService( IBraintreeGateway braintreeGateway, - IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - IUserRepository userRepository, - [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService + IUserRepository userRepository) : IPremiumUserBillingService { public async Task Credit(User user, decimal amount) { var customer = await subscriberService.GetCustomer(user); - // Negative credit represents a balance and all Stripe denomination is in cents. + // Negative credit represents a balance, and all Stripe denomination is in cents. var credit = (long)(amount * -100); if (customer == null) @@ -184,7 +179,7 @@ public class PremiumUserBillingService( City = customerSetup.TaxInformation.City, PostalCode = customerSetup.TaxInformation.PostalCode, State = customerSetup.TaxInformation.State, - Country = customerSetup.TaxInformation.Country, + Country = customerSetup.TaxInformation.Country }, Description = user.Name, Email = user.Email, @@ -324,6 +319,10 @@ public class PremiumUserBillingService( var subscriptionCreateOptions = new SubscriptionCreateOptions { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, @@ -337,18 +336,6 @@ public class PremiumUserBillingService( OffSession = true }; - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) - { - automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); - } - else - { - subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported, - }; - } - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); if (usingPayPal) @@ -380,7 +367,7 @@ public class PremiumUserBillingService( City = taxInformation.City, PostalCode = taxInformation.PostalCode, State = taxInformation.State, - Country = taxInformation.Country, + Country = taxInformation.Country }, Expand = ["tax"], Tax = new CustomerTaxOptions diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 10247cdf92..75a1bf76ec 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -1,7 +1,10 @@ -using Bit.Core.Billing.Caches; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Services; using Bit.Core.Entities; @@ -28,8 +31,7 @@ public class SubscriberService( ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ITaxService taxService, - IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService + ITaxService taxService) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -128,7 +130,7 @@ public class SubscriberService( [subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion }, Email = subscriber.BillingEmailAddress(), - PaymentMethodNonce = paymentMethodNonce, + PaymentMethodNonce = paymentMethodNonce }); if (customerResult.IsSuccess()) @@ -482,7 +484,7 @@ public class SubscriberService( var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First(); - // Find the customer's existing setup intents that should be cancelled. + // Find the customer's existing setup intents that should be canceled. var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer) .Where(si => si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action"); @@ -519,7 +521,7 @@ public class SubscriberService( await stripeAdapter.PaymentMethodAttachAsync(token, new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); - // Find the customer's existing setup intents that should be cancelled. + // Find the customer's existing setup intents that should be canceled. var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer) .Where(si => si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action"); @@ -637,7 +639,8 @@ public class SubscriberService( logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", taxInformation.Country, taxInformation.TaxId); - throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError"); + + throw new BadRequestException("billingTaxIdTypeInferenceError"); } } @@ -654,53 +657,84 @@ public class SubscriberService( logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", taxInformation.TaxId, taxInformation.Country); - throw new Exceptions.BadRequestException("billingInvalidTaxIdError"); + + throw new BadRequestException("billingInvalidTaxIdError"); + default: logger.LogError(e, "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", taxInformation.TaxId, taxInformation.Country, customer.Id); - throw new Exceptions.BadRequestException("billingTaxIdCreationError"); + + throw new BadRequestException("billingTaxIdCreationError"); } } } - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + var subscription = + customer.Subscriptions.First(subscription => subscription.Id == subscriber.GatewaySubscriptionId); + + var isBusinessUseSubscriber = subscriber switch { - if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + Organization organization => organization.PlanType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families, + Provider => true, + _ => false + }; + + var setNonUSBusinessUseToReverseCharge = + featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge && isBusinessUseSubscriber) + { + switch (customer) { - var subscriptionGetOptions = new SubscriptionGetOptions + case { - Expand = ["customer.tax", "customer.tax_ids"] - }; - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); - var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); - var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription); - if (automaticTaxOptions?.AutomaticTax?.Enabled != null) + Address.Country: not "US", + TaxExempt: not StripeConstants.TaxExempt.Reverse + }: + await stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); + break; + case { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions); - } + Address.Country: "US", + TaxExempt: StripeConstants.TaxExempt.Reverse + }: + await stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.None }); + break; } - } - else - { - if (SubscriberIsEligibleForAutomaticTax(subscriber, customer)) + + if (!subscription.AutomaticTax.Enabled) { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, + await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } }); } + } + else + { + var automaticTaxShouldBeEnabled = subscriber switch + { + User => true, + Organization organization => organization.PlanType.GetProductTier() == ProductTierType.Families || + customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false), + Provider => customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false), + _ => false + }; - return; - - bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) - => !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && - (localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && - localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; + if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled) + { + await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } } } diff --git a/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs b/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs index 310aced130..6affc57354 100644 --- a/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs +++ b/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs @@ -76,7 +76,7 @@ public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : I private bool ShouldBeEnabled(Customer customer) { - if (!customer.HasTaxLocationVerified()) + if (!customer.HasRecognizedTaxLocation()) { return false; } diff --git a/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs b/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs index e89fc6a3b3..615222259e 100644 --- a/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs +++ b/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs @@ -59,6 +59,6 @@ public class PersonalUseAutomaticTaxStrategy(IFeatureService featureService) : I private static bool ShouldBeEnabled(Customer customer) { - return customer.HasTaxLocationVerified(); + return customer.HasRecognizedTaxLocation(); } } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 707001ddcc..694521c14e 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -143,13 +143,13 @@ public static class FeatureFlagKeys public const string UsePricingService = "use-pricing-service"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method"; - public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements"; public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion"; public const string PM199566_UpdateMSPToChargeAutomatically = "pm-199566-update-msp-to-charge-automatically"; public const string PM19956_RequireProviderPaymentMethodDuringSetup = "pm-19956-require-provider-payment-method-during-setup"; public const string UseOrganizationWarningsService = "use-organization-warnings-service"; public const string PM20322_AllowTrialLength0 = "pm-20322-allow-trial-length-0"; + public const string PM21092_SetNonUSBusinessUseToReverseCharge = "pm-21092-set-non-us-business-use-to-reverse-charge"; /* Data Insights and Reporting Team */ public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application"; diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index 3fdb829cf4..af96b88ee6 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Models; using Bit.Core.Billing.Tax.Requests; using Bit.Core.Billing.Tax.Responses; using Bit.Core.Entities; -using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; @@ -30,8 +29,6 @@ public interface IPaymentService Task AdjustServiceAccountsAsync(Organization organization, Plan plan, int additionalServiceAccounts); Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false); Task ReinstateSubscriptionAsync(ISubscriber subscriber); - Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, - string paymentToken, TaxInfo taxInfo = null); Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount); Task GetBillingAsync(ISubscriber subscriber); Task GetBillingHistoryAsync(ISubscriber subscriber); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 65c0525535..34be6d59c5 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1,13 +1,13 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Contracts; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Requests; using Bit.Core.Billing.Tax.Responses; using Bit.Core.Billing.Tax.Services; @@ -38,7 +38,6 @@ public class StripePaymentService : IPaymentService private readonly IGlobalSettings _globalSettings; private readonly IFeatureService _featureService; private readonly ITaxService _taxService; - private readonly ISubscriberService _subscriberService; private readonly IPricingClient _pricingClient; private readonly IAutomaticTaxFactory _automaticTaxFactory; private readonly IAutomaticTaxStrategy _personalUseTaxStrategy; @@ -51,7 +50,6 @@ public class StripePaymentService : IPaymentService IGlobalSettings globalSettings, IFeatureService featureService, ITaxService taxService, - ISubscriberService subscriberService, IPricingClient pricingClient, IAutomaticTaxFactory automaticTaxFactory, [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy) @@ -63,7 +61,6 @@ public class StripePaymentService : IPaymentService _globalSettings = globalSettings; _featureService = featureService; _taxService = taxService; - _subscriberService = subscriberService; _pricingClient = pricingClient; _automaticTaxFactory = automaticTaxFactory; _personalUseTaxStrategy = personalUseTaxStrategy; @@ -136,15 +133,68 @@ public class StripePaymentService : IPaymentService if (subscriptionUpdate is CompleteSubscriptionUpdate) { - if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + var setNonUSBusinessUseToReverseCharge = + _featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge) { - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, updatedItemOptions.Select(x => x.Plan ?? x.Price)); - var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); - automaticTaxStrategy.SetUpdateOptions(subUpdateOptions, sub); + if (sub.Customer is + { + Address.Country: not "US", + TaxExempt: not StripeConstants.TaxExempt.Reverse + }) + { + await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); + } + + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - else + else if (sub.Customer.HasRecognizedTaxLocation()) { - subUpdateOptions.EnableAutomaticTax(sub.Customer, sub); + switch (subscriber) + { + case User: + { + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + break; + } + case Organization: + { + if (sub.Customer.Address.Country == "US") + { + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + } + else + { + var familyPriceIds = (await Task.WhenAll( + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually))) + .Select(plan => plan.PasswordManager.StripePlanId); + + var updateIsForPersonalUse = updatedItemOptions + .Select(option => option.Price) + .Intersect(familyPriceIds) + .Any(); + + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = updateIsForPersonalUse || sub.Customer.TaxIds.Any() + }; + } + + break; + } + case Provider: + { + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = sub.Customer.Address.Country == "US" || + sub.Customer.TaxIds.Any() + }; + break; + } + } } } @@ -202,7 +252,7 @@ public class StripePaymentService : IPaymentService } else if (!invoice.Paid) { - // Pay invoice with no charge to customer this completes the invoice immediately without waiting the scheduled 1h + // Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); paymentIntentClientSecret = null; } @@ -585,309 +635,6 @@ public class StripePaymentService : IPaymentService } } - public async Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, - string paymentToken, TaxInfo taxInfo = null) - { - if (subscriber == null) - { - throw new ArgumentNullException(nameof(subscriber)); - } - - if (subscriber.Gateway.HasValue && subscriber.Gateway.Value != GatewayType.Stripe) - { - throw new GatewayException("Switching from one payment type to another is not supported. " + - "Contact us for assistance."); - } - - var createdCustomer = false; - Braintree.Customer braintreeCustomer = null; - string stipeCustomerSourceToken = null; - string stipeCustomerPaymentMethodId = null; - var stripeCustomerMetadata = new Dictionary - { - { "region", _globalSettings.BaseServiceUri.CloudRegion } - }; - var stripePaymentMethod = paymentMethodType is PaymentMethodType.Card or PaymentMethodType.BankAccount; - - Customer customer = null; - - if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var options = new CustomerGetOptions { Expand = ["sources", "tax", "subscriptions"] }; - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, options); - if (customer.Metadata?.Any() ?? false) - { - stripeCustomerMetadata = customer.Metadata; - } - } - - var hadBtCustomer = stripeCustomerMetadata.ContainsKey("btCustomerId"); - if (stripePaymentMethod) - { - if (paymentToken.StartsWith("pm_")) - { - stipeCustomerPaymentMethodId = paymentToken; - } - else - { - stipeCustomerSourceToken = paymentToken; - } - } - else if (paymentMethodType == PaymentMethodType.PayPal) - { - if (hadBtCustomer) - { - var pmResult = await _btGateway.PaymentMethod.CreateAsync(new Braintree.PaymentMethodRequest - { - CustomerId = stripeCustomerMetadata["btCustomerId"], - PaymentMethodNonce = paymentToken - }); - - if (pmResult.IsSuccess()) - { - var customerResult = await _btGateway.Customer.UpdateAsync( - stripeCustomerMetadata["btCustomerId"], new Braintree.CustomerRequest - { - DefaultPaymentMethodToken = pmResult.Target.Token - }); - - if (customerResult.IsSuccess() && customerResult.Target.PaymentMethods.Length > 0) - { - braintreeCustomer = customerResult.Target; - } - else - { - await _btGateway.PaymentMethod.DeleteAsync(pmResult.Target.Token); - hadBtCustomer = false; - } - } - else - { - hadBtCustomer = false; - } - } - - if (!hadBtCustomer) - { - var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest - { - PaymentMethodNonce = paymentToken, - Email = subscriber.BillingEmailAddress(), - Id = subscriber.BraintreeCustomerIdPrefix() + subscriber.Id.ToString("N").ToLower() + - Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false), - CustomFields = new Dictionary - { - [subscriber.BraintreeIdField()] = subscriber.Id.ToString(), - [subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion - } - }); - - if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) - { - throw new GatewayException("Failed to create PayPal customer record."); - } - - braintreeCustomer = customerResult.Target; - } - } - else - { - throw new GatewayException("Payment method is not supported at this time."); - } - - if (stripeCustomerMetadata.ContainsKey("btCustomerId")) - { - if (braintreeCustomer?.Id != stripeCustomerMetadata["btCustomerId"]) - { - stripeCustomerMetadata["btCustomerId_old"] = stripeCustomerMetadata["btCustomerId"]; - } - - stripeCustomerMetadata["btCustomerId"] = braintreeCustomer?.Id; - } - else if (!string.IsNullOrWhiteSpace(braintreeCustomer?.Id)) - { - stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); - } - - try - { - if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)) - { - taxInfo.TaxIdType = taxInfo.TaxIdType ?? - _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber); - } - - if (customer == null) - { - customer = await _stripeAdapter.CustomerCreateAsync(new CustomerCreateOptions - { - Description = subscriber.BillingName(), - Email = subscriber.BillingEmailAddress(), - Metadata = stripeCustomerMetadata, - Source = stipeCustomerSourceToken, - PaymentMethod = stipeCustomerPaymentMethodId, - InvoiceSettings = new CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = stipeCustomerPaymentMethodId, - CustomFields = - [ - new CustomerInvoiceSettingsCustomFieldOptions() - { - Name = subscriber.SubscriberType(), - Value = subscriber.GetFormattedInvoiceName() - } - - ] - }, - Address = taxInfo == null ? null : new AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState - }, - TaxIdData = string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber) - ? [] - : [ - new CustomerTaxIdDataOptions - { - Type = taxInfo.TaxIdType, - Value = taxInfo.TaxIdNumber - } - ], - Expand = ["sources", "tax", "subscriptions"], - }); - - subscriber.Gateway = GatewayType.Stripe; - subscriber.GatewayCustomerId = customer.Id; - createdCustomer = true; - } - - if (!createdCustomer) - { - string defaultSourceId = null; - string defaultPaymentMethodId = null; - if (stripePaymentMethod) - { - if (!string.IsNullOrWhiteSpace(stipeCustomerSourceToken) && paymentToken.StartsWith("btok_")) - { - var bankAccount = await _stripeAdapter.BankAccountCreateAsync(customer.Id, new BankAccountCreateOptions - { - Source = paymentToken - }); - defaultSourceId = bankAccount.Id; - } - else if (!string.IsNullOrWhiteSpace(stipeCustomerPaymentMethodId)) - { - await _stripeAdapter.PaymentMethodAttachAsync(stipeCustomerPaymentMethodId, - new PaymentMethodAttachOptions { Customer = customer.Id }); - defaultPaymentMethodId = stipeCustomerPaymentMethodId; - } - } - - if (customer.Sources != null) - { - foreach (var source in customer.Sources.Where(s => s.Id != defaultSourceId)) - { - if (source is BankAccount) - { - await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); - } - else if (source is Card) - { - await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id); - } - } - } - - var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging(new PaymentMethodListOptions - { - Customer = customer.Id, - Type = "card" - }); - foreach (var cardMethod in cardPaymentMethods.Where(m => m.Id != defaultPaymentMethodId)) - { - await _stripeAdapter.PaymentMethodDetachAsync(cardMethod.Id, new PaymentMethodDetachOptions()); - } - - await _subscriberService.UpdateTaxInformation(subscriber, TaxInformation.From(taxInfo)); - - customer = await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions - { - Metadata = stripeCustomerMetadata, - DefaultSource = defaultSourceId, - InvoiceSettings = new CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = defaultPaymentMethodId, - CustomFields = - [ - new CustomerInvoiceSettingsCustomFieldOptions() - { - Name = subscriber.SubscriberType(), - Value = subscriber.GetFormattedInvoiceName() - } - ] - }, - Expand = ["tax", "subscriptions"] - }); - } - - if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) - { - if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) - { - var subscriptionGetOptions = new SubscriptionGetOptions - { - Expand = ["customer.tax", "customer.tax_ids"] - }; - var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); - - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); - var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); - var subscriptionUpdateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); - - if (subscriptionUpdateOptions != null) - { - _ = await _stripeAdapter.SubscriptionUpdateAsync( - subscriber.GatewaySubscriptionId, - subscriptionUpdateOptions); - } - } - } - else - { - if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) && - customer.Subscriptions.Any(sub => - sub.Id == subscriber.GatewaySubscriptionId && - !sub.AutomaticTax.Enabled) && - customer.HasTaxLocationVerified()) - { - var subscriptionUpdateOptions = new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, - DefaultTaxRates = [] - }; - - _ = await _stripeAdapter.SubscriptionUpdateAsync( - subscriber.GatewaySubscriptionId, - subscriptionUpdateOptions); - } - } - } - catch - { - if (braintreeCustomer != null && !hadBtCustomer) - { - await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); - } - throw; - } - - return createdCustomer; - } - public async Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount) { Customer customer = null; @@ -1018,7 +765,7 @@ public class StripePaymentService : IPaymentService var address = customer.Address; var taxId = customer.TaxIds?.FirstOrDefault(); - // Line1 is required, so if missing we're using the subscriber name + // Line1 is required, so if missing we're using the subscriber name, // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 if (address != null && string.IsNullOrWhiteSpace(address.Line1)) { diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index b1f78ed987..3fb134fda8 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -3,14 +3,11 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Services; using Bit.Core.Settings; -using Bit.Core.Test.Billing.Tax.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Braintree; @@ -195,7 +192,7 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); ; + .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); } #endregion @@ -1029,7 +1026,7 @@ public class SubscriberServiceTests stripeAdapter .PaymentMethodListAutoPagingAsync(Arg.Any()) - .Returns(GetPaymentMethodsAsync(new List())); + .Returns(GetPaymentMethodsAsync(new List())); await sutProvider.Sut.RemovePaymentSource(organization); @@ -1061,7 +1058,7 @@ public class SubscriberServiceTests stripeAdapter .PaymentMethodListAutoPagingAsync(Arg.Any()) - .Returns(GetPaymentMethodsAsync(new List + .Returns(GetPaymentMethodsAsync(new List { new () { @@ -1086,8 +1083,8 @@ public class SubscriberServiceTests .PaymentMethodDetachAsync(cardId); } - private static async IAsyncEnumerable GetPaymentMethodsAsync( - IEnumerable paymentMethods) + private static async IAsyncEnumerable GetPaymentMethodsAsync( + IEnumerable paymentMethods) { foreach (var paymentMethod in paymentMethods) { @@ -1598,14 +1595,22 @@ public class SubscriberServiceTests City = "Example Town", State = "NY" }, - TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } + TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] }, + Subscriptions = new StripeList + { + Data = [ + new Subscription + { + Id = provider.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } }); var subscription = new Subscription { Items = new StripeList() }; sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) .Returns(subscription); - sutProvider.GetDependency().CreateAsync(Arg.Any()) - .Returns(new FakeAutomaticTaxStrategy(true)); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); @@ -1623,6 +1628,98 @@ public class SubscriberServiceTests await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( options => options.Type == "us_ein" && options.Value == taxInformation.TaxId)); + + await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); + } + + [Theory, BitAutoData] + public async Task UpdateTaxInformation_NonUser_ReverseCharge_MakesCorrectInvocations( + Provider provider, + SutProvider sutProvider) + { + var stripeAdapter = sutProvider.GetDependency(); + + var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; + + stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( + options => options.Expand.Contains("tax_ids"))).Returns(customer); + + var taxInformation = new TaxInformation( + "CA", + "12345", + "123456789", + "us_ein", + "123 Example St.", + null, + "Example Town", + "NY"); + + sutProvider.GetDependency() + .CustomerUpdateAsync( + Arg.Is(p => p == provider.GatewayCustomerId), + Arg.Is(options => + options.Address.Country == "CA" && + options.Address.PostalCode == "12345" && + options.Address.Line1 == "123 Example St." && + options.Address.Line2 == null && + options.Address.City == "Example Town" && + options.Address.State == "NY")) + .Returns(new Customer + { + Id = provider.GatewayCustomerId, + Address = new Address + { + Country = "CA", + PostalCode = "12345", + Line1 = "123 Example St.", + Line2 = null, + City = "Example Town", + State = "NY" + }, + TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] }, + Subscriptions = new StripeList + { + Data = [ + new Subscription + { + Id = provider.GatewaySubscriptionId, + CustomerId = provider.GatewayCustomerId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } + }); + + var subscription = new Subscription { Items = new StripeList() }; + sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + .Returns(subscription); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + + await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); + + await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + options => + options.Address.Country == taxInformation.Country && + options.Address.PostalCode == taxInformation.PostalCode && + options.Address.Line1 == taxInformation.Line1 && + options.Address.Line2 == taxInformation.Line2 && + options.Address.City == taxInformation.City && + options.Address.State == taxInformation.State)); + + await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1"); + + await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( + options => options.Type == "us_ein" && + options.Value == taxInformation.TaxId)); + + await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, + Arg.Is(options => options.TaxExempt == StripeConstants.TaxExempt.Reverse)); + + await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); } #endregion