diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index 2187f98a80..a0a517b798 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -1,6 +1,7 @@ using Bit.Billing.Constants; using Bit.Billing.Models; using Bit.Billing.Services; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.Context; using Bit.Core.Enums; @@ -23,6 +24,7 @@ using Event = Stripe.Event; using JsonSerializer = System.Text.Json.JsonSerializer; using PaymentMethod = Stripe.PaymentMethod; using Subscription = Stripe.Subscription; +using TaxRate = Bit.Core.Entities.TaxRate; using Transaction = Bit.Core.Entities.Transaction; using TransactionType = Bit.Core.Enums.TransactionType; @@ -52,6 +54,7 @@ public class StripeController : Controller private readonly GlobalSettings _globalSettings; private readonly IStripeEventService _stripeEventService; private readonly IStripeFacade _stripeFacade; + private readonly IFeatureService _featureService; public StripeController( GlobalSettings globalSettings, @@ -70,7 +73,8 @@ public class StripeController : Controller IUserRepository userRepository, ICurrentContext currentContext, IStripeEventService stripeEventService, - IStripeFacade stripeFacade) + IStripeFacade stripeFacade, + IFeatureService featureService) { _billingSettings = billingSettings?.Value; _hostingEnvironment = hostingEnvironment; @@ -97,6 +101,7 @@ public class StripeController : Controller _globalSettings = globalSettings; _stripeEventService = stripeEventService; _stripeFacade = stripeFacade; + _featureService = featureService; } [HttpPost("webhook")] @@ -242,17 +247,29 @@ public class StripeController : Controller $"Received null Subscription from Stripe for ID '{invoice.SubscriptionId}' while processing Event with ID '{parsedEvent.Id}'"); } - if (!subscription.AutomaticTax.Enabled) + var pm5766AutomaticTaxIsEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax); + if (pm5766AutomaticTaxIsEnabled) { - subscription = await _stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions - { - DefaultTaxRates = new List(), - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); + var customer = await _stripeFacade.GetCustomer(subscription.CustomerId); + if (!subscription.AutomaticTax.Enabled && + !string.IsNullOrEmpty(customer.Address?.PostalCode) && + !string.IsNullOrEmpty(customer.Address?.Country)) + { + subscription = await _stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + DefaultTaxRates = new List(), + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } } - var (organizationId, userId) = GetIdsFromMetaData(subscription.Metadata); + + var updatedSubscription = pm5766AutomaticTaxIsEnabled + ? subscription + : await VerifyCorrectTaxRateForCharge(invoice, subscription); + + var (organizationId, userId) = GetIdsFromMetaData(updatedSubscription.Metadata); var invoiceLineItemDescriptions = invoice.Lines.Select(i => i.Description).ToList(); @@ -273,7 +290,7 @@ public class StripeController : Controller if (organizationId.HasValue) { - if (IsSponsoredSubscription(subscription)) + if (IsSponsoredSubscription(updatedSubscription)) { await _validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value); } @@ -321,22 +338,20 @@ public class StripeController : Controller Tuple ids = null; Subscription subscription = null; - var subscriptionService = new SubscriptionService(); if (charge.InvoiceId != null) { - var invoiceService = new InvoiceService(); - var invoice = await invoiceService.GetAsync(charge.InvoiceId); + var invoice = await _stripeFacade.GetInvoice(charge.InvoiceId); if (invoice?.SubscriptionId != null) { - subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); ids = GetIdsFromMetaData(subscription?.Metadata); } } if (subscription == null || ids == null || (ids.Item1.HasValue && ids.Item2.HasValue)) { - var subscriptions = await subscriptionService.ListAsync(new SubscriptionListOptions + var subscriptions = await _stripeFacade.ListSubscriptions(new SubscriptionListOptions { Customer = charge.CustomerId }); @@ -490,8 +505,7 @@ public class StripeController : Controller var invoice = await _stripeEventService.GetInvoice(parsedEvent, true); if (invoice.Paid && invoice.BillingReason == "subscription_create") { - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); if (subscription?.Status == StripeSubscriptionStatus.Active) { if (DateTime.UtcNow - invoice.Created < TimeSpan.FromMinutes(1)) @@ -596,7 +610,6 @@ public class StripeController : Controller return; } - var subscriptionService = new SubscriptionService(); var subscriptionListOptions = new SubscriptionListOptions { Customer = paymentMethod.CustomerId, @@ -607,7 +620,7 @@ public class StripeController : Controller StripeList unpaidSubscriptions; try { - unpaidSubscriptions = await subscriptionService.ListAsync(subscriptionListOptions); + unpaidSubscriptions = await _stripeFacade.ListSubscriptions(subscriptionListOptions); } catch (Exception e) { @@ -702,8 +715,7 @@ public class StripeController : Controller private async Task AttemptToPayInvoiceAsync(Invoice invoice, bool attemptToPayWithStripe = false) { - var customerService = new CustomerService(); - var customer = await customerService.GetAsync(invoice.CustomerId); + var customer = await _stripeFacade.GetCustomer(invoice.CustomerId); if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) { @@ -728,8 +740,7 @@ public class StripeController : Controller return false; } - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); var ids = GetIdsFromMetaData(subscription?.Metadata); if (!ids.Item1.HasValue && !ids.Item2.HasValue) { @@ -797,10 +808,9 @@ public class StripeController : Controller return false; } - var invoiceService = new InvoiceService(); try { - await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions + await _stripeFacade.UpdateInvoice(invoice.Id, new InvoiceUpdateOptions { Metadata = new Dictionary { @@ -809,14 +819,14 @@ public class StripeController : Controller transactionResult.Target.PayPalDetails?.AuthorizationId } }); - await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); + await _stripeFacade.PayInvoice(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); } catch (Exception e) { await _btGateway.Transaction.RefundAsync(transactionResult.Target.Id); if (e.Message.Contains("Invoice is already paid")) { - await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions + await _stripeFacade.UpdateInvoice(invoice.Id, new InvoiceUpdateOptions { Metadata = invoice.Metadata }); @@ -834,8 +844,7 @@ public class StripeController : Controller { try { - var invoiceService = new InvoiceService(); - await invoiceService.PayAsync(invoice.Id); + await _stripeFacade.PayInvoice(invoice.Id); return true; } catch (Exception e) @@ -855,6 +864,41 @@ public class StripeController : Controller invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null; } + private async Task VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription) + { + if (string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.Country) || + string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.PostalCode)) + { + return subscription; + } + + var localBitwardenTaxRates = await _taxRateRepository.GetByLocationAsync( + new TaxRate() + { + Country = invoice.CustomerAddress.Country, + PostalCode = invoice.CustomerAddress.PostalCode + } + ); + + if (!localBitwardenTaxRates.Any()) + { + return subscription; + } + + var stripeTaxRate = await _stripeFacade.GetTaxRate(localBitwardenTaxRates.First().Id); + if (stripeTaxRate == null || subscription.DefaultTaxRates.Any(x => x == stripeTaxRate)) + { + return subscription; + } + + subscription.DefaultTaxRates = new List { stripeTaxRate }; + + var subscriptionOptions = new SubscriptionUpdateOptions { DefaultTaxRates = new List { stripeTaxRate.Id } }; + subscription = await _stripeFacade.UpdateSubscription(subscription.Id, subscriptionOptions); + + return subscription; + } + private static bool IsSponsoredSubscription(Subscription subscription) => StaticStore.SponsoredPlans.Any(p => p.StripePlanId == subscription.Id); @@ -862,8 +906,7 @@ public class StripeController : Controller { if (!invoice.Paid && invoice.AttemptCount > 1 && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice)) { - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); // attempt count 4 = 11 days after initial failure if (invoice.AttemptCount <= 3 || !subscription.Items.Any(i => i.Price.Id is PremiumPlanId or PremiumPlanIdAppStore)) @@ -873,23 +916,20 @@ public class StripeController : Controller } } - private async Task CancelSubscription(string subscriptionId) - { - await new SubscriptionService().CancelAsync(subscriptionId, new SubscriptionCancelOptions()); - } + private async Task CancelSubscription(string subscriptionId) => + await _stripeFacade.CancelSubscription(subscriptionId, new SubscriptionCancelOptions()); private async Task VoidOpenInvoices(string subscriptionId) { - var invoiceService = new InvoiceService(); var options = new InvoiceListOptions { Status = StripeInvoiceStatus.Open, Subscription = subscriptionId }; - var invoices = invoiceService.List(options); + var invoices = await _stripeFacade.ListInvoices(options); foreach (var invoice in invoices) { - await invoiceService.VoidInvoiceAsync(invoice.Id); + await _stripeFacade.VoidInvoice(invoice.Id); } } diff --git a/src/Billing/Services/IStripeFacade.cs b/src/Billing/Services/IStripeFacade.cs index 4a49c75ea2..836f15aed0 100644 --- a/src/Billing/Services/IStripeFacade.cs +++ b/src/Billing/Services/IStripeFacade.cs @@ -22,12 +22,40 @@ public interface IStripeFacade RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + Task> ListInvoices( + InvoiceListOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + + Task UpdateInvoice( + string invoiceId, + InvoiceUpdateOptions invoiceGetOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + + Task PayInvoice( + string invoiceId, + InvoicePayOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + + Task VoidInvoice( + string invoiceId, + InvoiceVoidOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + Task GetPaymentMethod( string paymentMethodId, PaymentMethodGetOptions paymentMethodGetOptions = null, RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + Task> ListSubscriptions( + SubscriptionListOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + Task GetSubscription( string subscriptionId, SubscriptionGetOptions subscriptionGetOptions = null, @@ -39,4 +67,16 @@ public interface IStripeFacade SubscriptionUpdateOptions subscriptionGetOptions = null, RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + + Task CancelSubscription( + string subscriptionId, + SubscriptionCancelOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + + Task GetTaxRate( + string taxRateId, + TaxRateGetOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); } diff --git a/src/Billing/Services/Implementations/StripeFacade.cs b/src/Billing/Services/Implementations/StripeFacade.cs index db60621029..fb42030e0c 100644 --- a/src/Billing/Services/Implementations/StripeFacade.cs +++ b/src/Billing/Services/Implementations/StripeFacade.cs @@ -9,6 +9,7 @@ public class StripeFacade : IStripeFacade private readonly InvoiceService _invoiceService = new(); private readonly PaymentMethodService _paymentMethodService = new(); private readonly SubscriptionService _subscriptionService = new(); + private readonly TaxRateService _taxRateService = new(); public async Task GetCharge( string chargeId, @@ -31,6 +32,31 @@ public class StripeFacade : IStripeFacade CancellationToken cancellationToken = default) => await _invoiceService.GetAsync(invoiceId, invoiceGetOptions, requestOptions, cancellationToken); + public async Task> ListInvoices( + InvoiceListOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + await _invoiceService.ListAsync(options, requestOptions, cancellationToken); + + public async Task UpdateInvoice( + string invoiceId, + InvoiceUpdateOptions invoiceGetOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + await _invoiceService.UpdateAsync(invoiceId, invoiceGetOptions, requestOptions, cancellationToken); + + public async Task PayInvoice(string invoiceId, InvoicePayOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + await _invoiceService.PayAsync(invoiceId, options, requestOptions, cancellationToken); + + public async Task VoidInvoice( + string invoiceId, + InvoiceVoidOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + await _invoiceService.VoidInvoiceAsync(invoiceId, options, requestOptions, cancellationToken); + public async Task GetPaymentMethod( string paymentMethodId, PaymentMethodGetOptions paymentMethodGetOptions = null, @@ -38,6 +64,11 @@ public class StripeFacade : IStripeFacade CancellationToken cancellationToken = default) => await _paymentMethodService.GetAsync(paymentMethodId, paymentMethodGetOptions, requestOptions, cancellationToken); + public async Task> ListSubscriptions(SubscriptionListOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + await _subscriptionService.ListAsync(options, requestOptions, cancellationToken); + public async Task GetSubscription( string subscriptionId, SubscriptionGetOptions subscriptionGetOptions = null, @@ -51,4 +82,18 @@ public class StripeFacade : IStripeFacade RequestOptions requestOptions = null, CancellationToken cancellationToken = default) => await _subscriptionService.UpdateAsync(subscriptionId, subscriptionUpdateOptions, requestOptions, cancellationToken); + + public async Task CancelSubscription( + string subscriptionId, + SubscriptionCancelOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + await _subscriptionService.CancelAsync(subscriptionId, options, requestOptions, cancellationToken); + + public async Task GetTaxRate( + string taxRateId, + TaxRateGetOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + await _taxRateService.GetAsync(taxRateId, options, requestOptions, cancellationToken); } diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 8ab5293e95..fcd6c78e38 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -139,8 +139,13 @@ public class OrganizationService : IOrganizationService } await _paymentService.SaveTaxInfoAsync(organization, taxInfo); - var updated = await _paymentService.UpdatePaymentMethodAsync(organization, - paymentMethodType, paymentToken, taxInfo); + var updated = await _paymentService.UpdatePaymentMethodAsync( + organization, + paymentMethodType, + paymentToken, + _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax) + ? taxInfo + : null); if (updated) { await ReplaceAndUpdateCacheAsync(organization); diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 3235e1db3c..1d5073df69 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -116,6 +116,8 @@ public static class FeatureFlagKeys /// public const string FlexibleCollectionsMigration = "flexible-collections-migration"; + public const string PM5766AutomaticTax = "PM-5766-automatic-tax"; + public static List GetAllKeys() { return typeof(FeatureFlagKeys).GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy) diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index f3a939650a..fcfa40d181 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -21,29 +21,29 @@ public class StripePaymentService : IPaymentService private const string SecretsManagerStandaloneDiscountId = "sm-standalone"; private readonly ITransactionRepository _transactionRepository; - private readonly IUserRepository _userRepository; private readonly ILogger _logger; private readonly Braintree.IBraintreeGateway _btGateway; private readonly ITaxRateRepository _taxRateRepository; private readonly IStripeAdapter _stripeAdapter; private readonly IGlobalSettings _globalSettings; + private readonly IFeatureService _featureService; public StripePaymentService( ITransactionRepository transactionRepository, - IUserRepository userRepository, ILogger logger, ITaxRateRepository taxRateRepository, IStripeAdapter stripeAdapter, Braintree.IBraintreeGateway braintreeGateway, - IGlobalSettings globalSettings) + IGlobalSettings globalSettings, + IFeatureService featureService) { _transactionRepository = transactionRepository; - _userRepository = userRepository; _logger = logger; _taxRateRepository = taxRateRepository; _stripeAdapter = stripeAdapter; _btGateway = braintreeGateway; _globalSettings = globalSettings; + _featureService = featureService; } public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, @@ -100,6 +100,28 @@ public class StripePaymentService : IPaymentService throw new GatewayException("Payment method is not supported at this time."); } + var pm5766AutomaticTaxIsEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax); + + if (!pm5766AutomaticTaxIsEnabled && + taxInfo != null && + !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && + !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) + { + var taxRateSearch = new TaxRate + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode + }; + var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); + + // should only be one tax rate per country/zip combo + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null) + { + taxInfo.StripeTaxRateId = taxRate.Id; + } + } + var subCreateOptions = new OrganizationPurchaseSubscriptionOptions(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon , additionalSmSeats, additionalServiceAccount); @@ -153,7 +175,10 @@ public class StripePaymentService : IPaymentService subCreateOptions.AddExpand("latest_invoice.payment_intent"); subCreateOptions.Customer = customer.Id; - subCreateOptions.AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true }; + if (pm5766AutomaticTaxIsEnabled) + { + subCreateOptions.AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true }; + } subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) @@ -236,11 +261,34 @@ public class StripePaymentService : IPaymentService throw new GatewayException("Could not find customer payment profile."); } - if (customer.Address is null && + var pm5766AutomaticTaxIsEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax); + var taxInfo = upgrade.TaxInfo; + + if (!pm5766AutomaticTaxIsEnabled && + taxInfo != null && + !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && + !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) + { + var taxRateSearch = new TaxRate + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode + }; + var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); + + // should only be one tax rate per country/zip combo + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null) + { + taxInfo.StripeTaxRateId = taxRate.Id; + } + } + + if (pm5766AutomaticTaxIsEnabled && !string.IsNullOrEmpty(upgrade.TaxInfo?.BillingAddressCountry) && !string.IsNullOrEmpty(upgrade.TaxInfo?.BillingAddressPostalCode)) { - var addressOptions = new Stripe.AddressOptions + var addressOptions = new AddressOptions { Country = upgrade.TaxInfo.BillingAddressCountry, PostalCode = upgrade.TaxInfo.BillingAddressPostalCode, @@ -250,17 +298,20 @@ public class StripePaymentService : IPaymentService City = upgrade.TaxInfo.BillingAddressCity, State = upgrade.TaxInfo.BillingAddressState, }; - var customerUpdateOptions = new Stripe.CustomerUpdateOptions { Address = addressOptions }; + var customerUpdateOptions = new CustomerUpdateOptions { Address = addressOptions }; customerUpdateOptions.AddExpand("default_source"); customerUpdateOptions.AddExpand("invoice_settings.default_payment_method"); customer = await _stripeAdapter.CustomerUpdateAsync(org.GatewayCustomerId, customerUpdateOptions); } - var subCreateOptions = new OrganizationUpgradeSubscriptionOptions(customer.Id, org, plan, upgrade) + var subCreateOptions = new OrganizationUpgradeSubscriptionOptions(customer.Id, org, plan, upgrade); + + if (pm5766AutomaticTaxIsEnabled) { - DefaultTaxRates = new List(), - AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true } - }; + subCreateOptions.DefaultTaxRates = new List(); + subCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + } + var (stripePaymentMethod, paymentMethodType) = IdentifyPaymentMethod(customer, subCreateOptions); var subscription = await ChargeForNewSubscriptionAsync(org, customer, false, @@ -457,6 +508,29 @@ public class StripePaymentService : IPaymentService Quantity = 1 }); + var pm5766AutomaticTaxIsEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax); + + if (!pm5766AutomaticTaxIsEnabled && + !string.IsNullOrWhiteSpace(taxInfo?.BillingAddressCountry) && + !string.IsNullOrWhiteSpace(taxInfo?.BillingAddressPostalCode)) + { + var taxRates = await _taxRateRepository.GetByLocationAsync( + new TaxRate + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode + } + ); + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null) + { + subCreateOptions.DefaultTaxRates = new List(1) + { + taxRate.Id + }; + } + } + if (additionalStorageGb > 0) { subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions @@ -466,7 +540,11 @@ public class StripePaymentService : IPaymentService }); } - subCreateOptions.AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true }; + if (pm5766AutomaticTaxIsEnabled) + { + subCreateOptions.DefaultTaxRates = new List(); + subCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + } var subscription = await ChargeForNewSubscriptionAsync(user, customer, createdStripeCustomer, stripePaymentMethod, paymentMethodType, subCreateOptions, braintreeCustomer); @@ -504,11 +582,14 @@ public class StripePaymentService : IPaymentService var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions { Customer = customer.Id, - SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), - AutomaticTax = - new Stripe.InvoiceAutomaticTaxOptions { Enabled = subCreateOptions.AutomaticTax.Enabled } + SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items) }); + if (_featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax)) + { + previewInvoice.AutomaticTax = new InvoiceAutomaticTax { Enabled = true }; + } + if (previewInvoice.AmountDue > 0) { var braintreeCustomerId = customer.Metadata != null && @@ -560,13 +641,22 @@ public class StripePaymentService : IPaymentService } else if (paymentMethodType == PaymentMethodType.Credit) { - var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions + var upcomingInvoiceOptions = new UpcomingInvoiceOptions { Customer = customer.Id, SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), - AutomaticTax = - new Stripe.InvoiceAutomaticTaxOptions { Enabled = subCreateOptions.AutomaticTax.Enabled } - }); + SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates, + }; + + var pm5766AutomaticTaxIsEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax); + if (pm5766AutomaticTaxIsEnabled) + { + upcomingInvoiceOptions.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; + upcomingInvoiceOptions.SubscriptionDefaultTaxRates = new List(); + } + + var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(upcomingInvoiceOptions); + if (previewInvoice.AmountDue > 0) { throw new GatewayException("Your account does not have enough credit available."); @@ -575,7 +665,12 @@ public class StripePaymentService : IPaymentService subCreateOptions.OffSession = true; subCreateOptions.AddExpand("latest_invoice.payment_intent"); - subCreateOptions.AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true }; + + if (_featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax)) + { + subCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + } + subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) { @@ -675,16 +770,41 @@ public class StripePaymentService : IPaymentService DaysUntilDue = daysUntilDue ?? 1, CollectionMethod = "send_invoice", ProrationDate = prorationDate, - DefaultTaxRates = new List(), - AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true } }; + var pm5766AutomaticTaxIsEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax); + if (pm5766AutomaticTaxIsEnabled) + { + subUpdateOptions.DefaultTaxRates = new List(); + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + } + if (!subscriptionUpdate.UpdateNeeded(sub)) { // No need to update subscription, quantity matches return null; } + if (!pm5766AutomaticTaxIsEnabled) + { + var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId); + + if (!string.IsNullOrWhiteSpace(customer?.Address?.Country) + && !string.IsNullOrWhiteSpace(customer?.Address?.PostalCode)) + { + var taxRates = await _taxRateRepository.GetByLocationAsync(new TaxRate + { + Country = customer.Address.Country, + PostalCode = customer.Address.PostalCode + }); + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null && !sub.DefaultTaxRates.Any(x => x.Equals(taxRate.Id))) + { + subUpdateOptions.DefaultTaxRates = new List(1) { taxRate.Id }; + } + } + } + string paymentIntentClientSecret = null; try { diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 1dd20fcb5a..c07e77b1c3 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -214,7 +214,7 @@ public static class ServiceCollectionExtensions PrivateKey = globalSettings.Braintree.PrivateKey }; }); - services.AddSingleton(); + services.AddScoped(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Services/StripePaymentServiceTests.cs index 171fab0fb5..b4dbdaa7f7 100644 --- a/test/Core.Test/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Services/StripePaymentServiceTests.cs @@ -701,6 +701,7 @@ public class StripePaymentServiceTests { organization.GatewaySubscriptionId = null; var stripeAdapter = sutProvider.GetDependency(); + var featureService = sutProvider.GetDependency(); stripeAdapter.CustomerGetAsync(default).ReturnsForAnyArgs(new Stripe.Customer { Id = "C-1", @@ -723,6 +724,7 @@ public class StripePaymentServiceTests AmountDue = 0 }); stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription { }); + featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax).Returns(true); var upgrade = new OrganizationUpgrade() {