diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index 841ce74b81..7eef4d7dbf 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -7,10 +7,12 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; +using Microsoft.Extensions.DependencyInjection; using Stripe; namespace Bit.Commercial.Core.AdminConsole.Providers; @@ -28,7 +30,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv private readonly ISubscriberService _subscriberService; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; private readonly IPricingClient _pricingClient; - private readonly IOrganizationAutomaticTaxStrategy _organizationAutomaticTaxStrategy; + private readonly IAutomaticTaxStrategy _automaticTaxStrategy; public RemoveOrganizationFromProviderCommand( IEventService eventService, @@ -42,7 +44,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv ISubscriberService subscriberService, IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, IPricingClient pricingClient, - IOrganizationAutomaticTaxStrategy organizationAutomaticTaxStrategy) + [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) { _eventService = eventService; _mailService = mailService; @@ -55,7 +57,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv _subscriberService = subscriberService; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; _pricingClient = pricingClient; - _organizationAutomaticTaxStrategy = organizationAutomaticTaxStrategy; + _automaticTaxStrategy = automaticTaxStrategy; } public async Task RemoveOrganizationFromProvider( @@ -132,7 +134,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }] }; - await _organizationAutomaticTaxStrategy.SetCreateOptionsAsync(subscriptionCreateOptions, customer); + _automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index 508f0a0d6f..b44adf392f 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -14,6 +14,7 @@ using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -22,6 +23,7 @@ using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using CsvHelper; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; @@ -41,7 +43,8 @@ public class ProviderBillingService( IStripeAdapter stripeAdapter, ISubscriberService subscriberService, ITaxService taxService, - IOrganizationAutomaticTaxStrategy organizationAutomaticTaxStrategy) : IProviderBillingService + [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) + : IProviderBillingService { [RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)] public async Task AddExistingOrganization( @@ -602,7 +605,7 @@ public class ProviderBillingService( ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations }; - await organizationAutomaticTaxStrategy.SetCreateOptionsAsync(subscriptionCreateOptions, customer); + automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); try { diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index d28f291599..48eda094e8 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -228,8 +228,8 @@ public class RemoveOrganizationFromProviderCommandTests Id = "subscription_id" }); - sutProvider.GetDependency() - .When(x => x.SetCreateOptionsAsync( + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( Arg.Is(options => options.Customer == organization.GatewayCustomerId && options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index d727f0c7d9..94b6e70edf 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -1018,8 +1018,8 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency() - .When(x => x.SetCreateOptionsAsync( + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( Arg.Is(options => options.Customer == "customer_id") , Arg.Is(p => p == customer))) diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index d91454af23..7428f10379 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -1,7 +1,7 @@ using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -21,8 +21,7 @@ public class UpcomingInvoiceHandler( IStripeEventUtilityService stripeEventUtilityService, IUserRepository userRepository, IValidateSponsorshipCommand validateSponsorshipCommand, - IIndividualAutomaticTaxStrategy individualAutomaticTaxStrategy, - IOrganizationAutomaticTaxStrategy organizationAutomaticTaxStrategy) + IAutomaticTaxFactory automaticTaxFactory) : IUpcomingInvoiceHandler { public async Task HandleAsync(Event parsedEvent) @@ -137,9 +136,9 @@ public class UpcomingInvoiceHandler( private async Task TryEnableAutomaticTaxAsync(Subscription subscription) { - var updateOptions = subscription.IsOrganization() - ? await organizationAutomaticTaxStrategy.GetUpdateOptionsAsync(subscription) - : individualAutomaticTaxStrategy.GetUpdateOptions(subscription); + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); + var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); if (updateOptions == null) { diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 016b671fdf..17285e0676 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -19,8 +19,9 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); - services.AddTransient(); - services.AddTransient(); + services.AddKeyedTransient(AutomaticTaxFactory.PersonalUse); + services.AddKeyedTransient(AutomaticTaxFactory.BusinessUse); + services.AddTransient(); services.AddLicenseServices(); services.AddPricingClient(); } diff --git a/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs b/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs index dd8cf372e9..7e7a98ccdf 100644 --- a/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs +++ b/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs @@ -5,7 +5,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Migration.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Repositories; -using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; @@ -21,8 +20,7 @@ public class OrganizationMigrator( IMigrationTrackerCache migrationTrackerCache, IOrganizationRepository organizationRepository, IPricingClient pricingClient, - IStripeAdapter stripeAdapter, - IOrganizationAutomaticTaxStrategy automaticTaxStrategy) : IOrganizationMigrator + IStripeAdapter stripeAdapter) : IOrganizationMigrator { private const string _cancellationComment = "Cancelled as part of provider migration to Consolidated Billing"; @@ -160,145 +158,6 @@ public class OrganizationMigrator( #endregion - #region Reverse - - private async Task RemoveMigrationRecordAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Removing migration record for organization ({OrganizationID})", organization.Id); - - var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id); - - if (migrationRecord != null) - { - await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord); - - logger.LogInformation( - "CB: Removed migration record for organization ({OrganizationID})", - organization.Id); - } - else - { - logger.LogInformation("CB: Did not remove migration record for organization ({OrganizationID}) as it does not exist", organization.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, ClientMigrationProgress.Reversed); - } - - private async Task RecreateSubscriptionAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Recreating subscription for organization ({OrganizationID})", organization.Id); - - if (!string.IsNullOrEmpty(organization.GatewaySubscriptionId)) - { - if (string.IsNullOrEmpty(organization.GatewayCustomerId)) - { - logger.LogError( - "CB: Cannot recreate subscription for organization ({OrganizationID}) as it does not have a Stripe customer", - organization.Id); - - throw new Exception(); - } - - var customer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, - new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] }); - - var collectionMethod = - customer.DefaultSource != null || - customer.InvoiceSettings?.DefaultPaymentMethod != null || - customer.Metadata.ContainsKey(Utilities.BraintreeCustomerIdKey) - ? StripeConstants.CollectionMethod.ChargeAutomatically - : StripeConstants.CollectionMethod.SendInvoice; - - var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); - - var items = new List - { - new () - { - Price = plan.PasswordManager.StripeSeatPlanId, - Quantity = organization.Seats - } - }; - - if (organization.MaxStorageGb.HasValue && plan.PasswordManager.BaseStorageGb.HasValue && organization.MaxStorageGb.Value > plan.PasswordManager.BaseStorageGb.Value) - { - var additionalStorage = organization.MaxStorageGb.Value - plan.PasswordManager.BaseStorageGb.Value; - - items.Add(new SubscriptionItemOptions - { - Price = plan.PasswordManager.StripeStoragePlanId, - Quantity = additionalStorage - }); - } - - var subscriptionCreateOptions = new SubscriptionCreateOptions - { - Customer = customer.Id, - CollectionMethod = collectionMethod, - DaysUntilDue = collectionMethod == StripeConstants.CollectionMethod.SendInvoice ? 30 : null, - Items = items, - Metadata = new Dictionary - { - [organization.GatewayIdField()] = organization.Id.ToString() - }, - OffSession = true, - ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, - TrialPeriodDays = plan.TrialPeriodDays - }; - - await automaticTaxStrategy.SetCreateOptionsAsync(subscriptionCreateOptions, customer); - - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); - - organization.GatewaySubscriptionId = subscription.Id; - - await organizationRepository.ReplaceAsync(organization); - - logger.LogInformation("CB: Recreated subscription for organization ({OrganizationID})", organization.Id); - } - else - { - logger.LogInformation( - "CB: Did not recreate subscription for organization ({OrganizationID}) as it already exists", - organization.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.RecreatedSubscription); - } - - private async Task ReverseOrganizationUpdateAsync(Guid providerId, Organization organization) - { - var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id); - - if (migrationRecord == null) - { - logger.LogError( - "CB: Cannot reverse migration for organization ({OrganizationID}) as it does not have a migration record", - organization.Id); - - throw new Exception(); - } - - var plan = await pricingClient.GetPlanOrThrow(migrationRecord.PlanType); - - ResetOrganizationPlan(organization, plan); - organization.MaxStorageGb = migrationRecord.MaxStorageGb; - organization.ExpirationDate = migrationRecord.ExpirationDate; - organization.MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats; - organization.Status = migrationRecord.Status; - - await organizationRepository.ReplaceAsync(organization); - - logger.LogInformation("CB: Reversed organization ({OrganizationID}) updates", - organization.Id); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.ResetOrganization); - } - - #endregion - #region Shared private static void ResetOrganizationPlan(Organization organization, Plan plan) diff --git a/src/Core/Billing/Services/Contracts/AutomaticTaxFactoryParameters.cs b/src/Core/Billing/Services/Contracts/AutomaticTaxFactoryParameters.cs new file mode 100644 index 0000000000..19a4f0bdfa --- /dev/null +++ b/src/Core/Billing/Services/Contracts/AutomaticTaxFactoryParameters.cs @@ -0,0 +1,30 @@ +#nullable enable +using Bit.Core.Billing.Enums; +using Bit.Core.Entities; + +namespace Bit.Core.Billing.Services.Contracts; + +public class AutomaticTaxFactoryParameters +{ + public AutomaticTaxFactoryParameters(PlanType planType) + { + PlanType = planType; + } + + public AutomaticTaxFactoryParameters(ISubscriber subscriber, IEnumerable prices) + { + Subscriber = subscriber; + Prices = prices; + } + + public AutomaticTaxFactoryParameters(IEnumerable prices) + { + Prices = prices; + } + + public ISubscriber? Subscriber { get; init; } + + public PlanType? PlanType { get; init; } + + public IEnumerable? Prices { get; init; } +} diff --git a/src/Core/Billing/Services/IAutomaticTaxFactory.cs b/src/Core/Billing/Services/IAutomaticTaxFactory.cs new file mode 100644 index 0000000000..3c853ac0d6 --- /dev/null +++ b/src/Core/Billing/Services/IAutomaticTaxFactory.cs @@ -0,0 +1,8 @@ +using Bit.Core.Billing.Services.Contracts; + +namespace Bit.Core.Billing.Services; + +public interface IAutomaticTaxFactory +{ + Task CreateAsync(AutomaticTaxFactoryParameters parameters); +} diff --git a/src/Core/Billing/Services/IIndividualAutomaticTaxStrategy.cs b/src/Core/Billing/Services/IAutomaticTaxStrategy.cs similarity index 92% rename from src/Core/Billing/Services/IIndividualAutomaticTaxStrategy.cs rename to src/Core/Billing/Services/IAutomaticTaxStrategy.cs index 4c2422d4fa..32be2a2750 100644 --- a/src/Core/Billing/Services/IIndividualAutomaticTaxStrategy.cs +++ b/src/Core/Billing/Services/IAutomaticTaxStrategy.cs @@ -3,7 +3,7 @@ using Stripe; namespace Bit.Core.Billing.Services; -public interface IIndividualAutomaticTaxStrategy +public interface IAutomaticTaxStrategy { /// /// diff --git a/src/Core/Billing/Services/IOrganizationAutomaticTaxStrategy.cs b/src/Core/Billing/Services/IOrganizationAutomaticTaxStrategy.cs deleted file mode 100644 index 5f0bb55e3d..0000000000 --- a/src/Core/Billing/Services/IOrganizationAutomaticTaxStrategy.cs +++ /dev/null @@ -1,19 +0,0 @@ -#nullable enable -using Stripe; - -namespace Bit.Core.Billing.Services; - -public interface IOrganizationAutomaticTaxStrategy -{ - /// - /// - /// - /// - /// - /// Returns if changes are to be applied to the subscription, returns null - /// otherwise. - /// - Task GetUpdateOptionsAsync(Subscription subscription); - Task SetCreateOptionsAsync(SubscriptionCreateOptions options, Customer customer); - Task SetUpdateOptionsAsync(SubscriptionUpdateOptions options, Subscription subscription); -} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs new file mode 100644 index 0000000000..2751a158c3 --- /dev/null +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/AutomaticTaxFactory.cs @@ -0,0 +1,48 @@ +#nullable enable +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services.Contracts; +using Bit.Core.Entities; + +namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; + +public class AutomaticTaxFactory(IPricingClient pricingClient) : IAutomaticTaxFactory +{ + public const string BusinessUse = "business-use"; + public const string PersonalUse = "personal-use"; + + private readonly Lazy>> _personalUsePlansTask = new(async () => + { + var plans = await Task.WhenAll( + pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), + pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)); + + return plans.Select(plan => plan.PasswordManager.StripePlanId); + }); + + public async Task CreateAsync(AutomaticTaxFactoryParameters parameters) + { + if (parameters.Subscriber is User) + { + return new PersonalUseAutomaticTaxStrategy(); + } + + if (parameters.PlanType.HasValue) + { + var plan = await pricingClient.GetPlanOrThrow(parameters.PlanType.Value); + return plan.CanBeUsedByBusiness + ? new BusinessUseAutomaticTaxStrategy() + : new PersonalUseAutomaticTaxStrategy(); + } + + var personalUsePlans = await _personalUsePlansTask.Value; + var plans = await pricingClient.ListPlans(); + + if (personalUsePlans.Any(x => plans.Any(y => y.PasswordManager.StripePlanId == x))) + { + return new PersonalUseAutomaticTaxStrategy(); + } + + return new BusinessUseAutomaticTaxStrategy(); + } +} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs new file mode 100644 index 0000000000..427067ffdf --- /dev/null +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/BusinessUseAutomaticTaxStrategy.cs @@ -0,0 +1,62 @@ +#nullable enable +using Bit.Core.Billing.Extensions; +using Stripe; + +namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; + +public class BusinessUseAutomaticTaxStrategy : IAutomaticTaxStrategy +{ + public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription) + { + var shouldBeEnabled = ShouldBeEnabled(subscription.Customer); + if (subscription.AutomaticTax.Enabled == shouldBeEnabled) + { + return null; + } + + var options = new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = shouldBeEnabled + }, + DefaultTaxRates = [] + }; + + return options; + } + + public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) + { + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = ShouldBeEnabled(customer) + }; + } + + public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription) + { + var shouldBeEnabled = ShouldBeEnabled(subscription.Customer); + + if (subscription.AutomaticTax.Enabled == shouldBeEnabled) + { + return; + } + + options.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = shouldBeEnabled + }; + options.DefaultTaxRates = []; + } + + private bool ShouldBeEnabled(Customer customer) + { + if (!customer.HasTaxLocationVerified()) + { + return false; + } + + return customer.Address.Country == "US"; + } +} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/OrganizationAutomaticTaxStrategy.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/OrganizationAutomaticTaxStrategy.cs deleted file mode 100644 index cd7597b7ba..0000000000 --- a/src/Core/Billing/Services/Implementations/AutomaticTax/OrganizationAutomaticTaxStrategy.cs +++ /dev/null @@ -1,97 +0,0 @@ -#nullable enable -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Pricing; -using Stripe; - -namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; - -public class OrganizationAutomaticTaxStrategy( - IPricingClient pricingClient) : IOrganizationAutomaticTaxStrategy -{ - private readonly Lazy>> _familyPriceIdsTask = new(async () => - { - var plans = await Task.WhenAll( - pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), - pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)); - - return plans.Select(plan => plan.PasswordManager.StripePlanId); - }); - - public async Task GetUpdateOptionsAsync(Subscription subscription) - { - var shouldBeEnabled = await ShouldBeEnabledAsync(subscription); - - var options = new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = shouldBeEnabled - }, - DefaultTaxRates = [] - }; - - return options; - } - - public async Task SetCreateOptionsAsync(SubscriptionCreateOptions options, Customer customer) - { - options.AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = await ShouldBeEnabledAsync(options, customer) - }; - } - - public async Task SetUpdateOptionsAsync(SubscriptionUpdateOptions options, Subscription subscription) - { - var shouldBeEnabled = await ShouldBeEnabledAsync(subscription); - - if (subscription.AutomaticTax.Enabled == shouldBeEnabled) - { - return; - } - - options.AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = shouldBeEnabled - }; - options.DefaultTaxRates = []; - } - - private async Task ShouldBeEnabledAsync(Subscription subscription) - { - if (!subscription.Customer.HasTaxLocationVerified()) - { - return false; - } - - bool shouldBeEnabled; - if (subscription.Customer.Address.Country == "US") - { - shouldBeEnabled = true; - } - else - { - var familyPriceIds = await _familyPriceIdsTask.Value; - shouldBeEnabled = subscription.Items.Select(item => item.Price.Id).Intersect(familyPriceIds).Any(); - } - - return shouldBeEnabled; - } - - private async Task ShouldBeEnabledAsync(SubscriptionCreateOptions options, Customer customer) - { - if (!customer.HasTaxLocationVerified()) - { - return false; - } - - if (customer.Address.Country == "US") - { - return true; - } - - var familyPriceIds = await _familyPriceIdsTask.Value; - return options.Items.Select(item => item.Price).Intersect(familyPriceIds).Any(); - } -} diff --git a/src/Core/Billing/Services/Implementations/AutomaticTax/IndividualAutomaticTaxStrategy.cs b/src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs similarity index 94% rename from src/Core/Billing/Services/Implementations/AutomaticTax/IndividualAutomaticTaxStrategy.cs rename to src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs index 230cb5be1c..5532b61097 100644 --- a/src/Core/Billing/Services/Implementations/AutomaticTax/IndividualAutomaticTaxStrategy.cs +++ b/src/Core/Billing/Services/Implementations/AutomaticTax/PersonalUseAutomaticTaxStrategy.cs @@ -4,7 +4,7 @@ using Stripe; namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; -public class IndividualAutomaticTaxStrategy : IIndividualAutomaticTaxStrategy +public class PersonalUseAutomaticTaxStrategy : IAutomaticTaxStrategy { public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer) { diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index df5322fd73..139fddaf5c 100644 --- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -4,6 +4,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -31,7 +32,7 @@ public class OrganizationBillingService( IStripeAdapter stripeAdapter, ISubscriberService subscriberService, ITaxService taxService, - IOrganizationAutomaticTaxStrategy organizationAutomaticTaxStrategy) : IOrganizationBillingService + IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService { public async Task Finalize(OrganizationSale sale) { @@ -370,15 +371,6 @@ public class OrganizationBillingService( } } - var customerHasTaxInfo = customer is - { - Address: - { - Country: not null and not "", - PostalCode: not null and not "" - } - }; - var subscriptionCreateOptions = new SubscriptionCreateOptions { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, @@ -392,7 +384,9 @@ public class OrganizationBillingService( TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays }; - await organizationAutomaticTaxStrategy.SetCreateOptionsAsync(subscriptionCreateOptions, customer); + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType); + var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); + automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); } diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 3aeeec4ce7..f1d4c95557 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; +using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -9,6 +10,7 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Braintree; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; using Customer = Stripe.Customer; @@ -26,7 +28,7 @@ public class PremiumUserBillingService( IStripeAdapter stripeAdapter, ISubscriberService subscriberService, IUserRepository userRepository, - IIndividualAutomaticTaxStrategy individualAutomaticTaxStrategy) : IPremiumUserBillingService + [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService { public async Task Credit(User user, decimal amount) { @@ -332,7 +334,7 @@ public class PremiumUserBillingService( OffSession = true }; - individualAutomaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 2031fad17f..441bda6bbe 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -1,6 +1,7 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -25,8 +26,7 @@ public class SubscriberService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ITaxService taxService, - IIndividualAutomaticTaxStrategy individualAutomaticTaxStrategy, - IOrganizationAutomaticTaxStrategy organizationAutomaticTaxStrategy) : ISubscriberService + IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -666,9 +666,9 @@ public class SubscriberService( if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) { var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - var automaticTaxOptions = subscriber.IsUser() - ? individualAutomaticTaxStrategy.GetUpdateOptions(subscription) - : await organizationAutomaticTaxStrategy.GetUpdateOptionsAsync(subscription); + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); + var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription); if (automaticTaxOptions?.AutomaticTax?.Enabled != null) { await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 3da1bc1ef3..02c76f6a35 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -9,6 +9,7 @@ using Bit.Core.Billing.Models.Api.Responses; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Contracts; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -36,8 +37,7 @@ public class StripePaymentService : IPaymentService private readonly ITaxService _taxService; private readonly ISubscriberService _subscriberService; private readonly IPricingClient _pricingClient; - private readonly IIndividualAutomaticTaxStrategy _individualAutomaticTaxStrategy; - private readonly IOrganizationAutomaticTaxStrategy _organizationAutomaticTaxStrategy; + private readonly IAutomaticTaxFactory _automaticTaxFactory; public StripePaymentService( ITransactionRepository transactionRepository, @@ -49,8 +49,7 @@ public class StripePaymentService : IPaymentService ITaxService taxService, ISubscriberService subscriberService, IPricingClient pricingClient, - IIndividualAutomaticTaxStrategy individualAutomaticTaxStrategy, - IOrganizationAutomaticTaxStrategy organizationAutomaticTaxStrategy) + IAutomaticTaxFactory automaticTaxFactory) { _transactionRepository = transactionRepository; _logger = logger; @@ -61,8 +60,7 @@ public class StripePaymentService : IPaymentService _taxService = taxService; _subscriberService = subscriberService; _pricingClient = pricingClient; - _individualAutomaticTaxStrategy = individualAutomaticTaxStrategy; - _organizationAutomaticTaxStrategy = organizationAutomaticTaxStrategy; + _automaticTaxFactory = automaticTaxFactory; } private async Task ChangeOrganizationSponsorship( @@ -130,7 +128,9 @@ public class StripePaymentService : IPaymentService new SubscriptionPendingInvoiceItemIntervalOptions { Interval = "month" }; } - await _organizationAutomaticTaxStrategy.SetUpdateOptionsAsync(subUpdateOptions, sub); + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, sub.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); + automaticTaxStrategy.SetUpdateOptions(subUpdateOptions, sub); if (!subscriptionUpdate.UpdateNeeded(sub)) { @@ -821,9 +821,9 @@ public class StripePaymentService : IPaymentService { var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - var subscriptionUpdateOptions = subscriber is User - ? _individualAutomaticTaxStrategy.GetUpdateOptions(subscription) - : await _organizationAutomaticTaxStrategy.GetUpdateOptionsAsync(subscription); + var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); + var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); + var subscriptionUpdateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); if (subscriptionUpdateOptions != null) {