diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index fff6b5271d..2fc44937a7 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -8,6 +8,7 @@ using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Context; @@ -82,7 +83,7 @@ public class ProviderService : IProviderService _pricingClient = pricingClient; } - public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null) + public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) { var owner = await _userService.GetUserByIdAsync(ownerUserId); if (owner == null) @@ -111,7 +112,20 @@ public class ProviderService : IProviderService { throw new BadRequestException("Both address and postal code are required to set up your provider."); } - var customer = await _providerBillingService.SetupCustomer(provider, taxInfo); + + var requireProviderPaymentMethodDuringSetup = + _featureService.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup); + + if (requireProviderPaymentMethodDuringSetup && tokenizedPaymentSource is not + { + Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal, + Token: not null and not "" + }) + { + throw new BadRequestException("A payment method is required to set up your provider."); + } + + var customer = await _providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); provider.GatewayCustomerId = customer.Id; var subscription = await _providerBillingService.SetupSubscription(provider); provider.GatewaySubscriptionId = subscription.Id; diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index bdfff079cf..f049d6c8df 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -6,9 +6,11 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing; +using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Repositories; @@ -21,14 +23,20 @@ using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; +using Braintree; using CsvHelper; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; +using static Bit.Core.Billing.Utilities; +using Customer = Stripe.Customer; +using Subscription = Stripe.Subscription; + namespace Bit.Commercial.Core.Billing; public class ProviderBillingService( + IBraintreeGateway braintreeGateway, IEventService eventService, IFeatureService featureService, IGlobalSettings globalSettings, @@ -39,6 +47,7 @@ public class ProviderBillingService( IProviderOrganizationRepository providerOrganizationRepository, IProviderPlanRepository providerPlanRepository, IProviderUserRepository providerUserRepository, + ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, ITaxService taxService, @@ -463,7 +472,8 @@ public class ProviderBillingService( public async Task SetupCustomer( Provider provider, - TaxInfo taxInfo) + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource = null) { if (taxInfo is not { @@ -532,13 +542,97 @@ public class ProviderBillingService( options.Coupon = provider.DiscountId; } + var requireProviderPaymentMethodDuringSetup = + featureService.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup); + + var braintreeCustomerId = ""; + + if (requireProviderPaymentMethodDuringSetup) + { + if (tokenizedPaymentSource is not + { + Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal, + Token: not null and not "" + }) + { + logger.LogError("Cannot create customer for provider ({ProviderID}) without a payment method", provider.Id); + throw new BillingException(); + } + + var (type, token) = tokenizedPaymentSource; + + // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault + switch (type) + { + case PaymentMethodType.BankAccount: + { + var setupIntent = + (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = token })) + .FirstOrDefault(); + + if (setupIntent == null) + { + logger.LogError("Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account", provider.Id); + throw new BillingException(); + } + + await setupIntentCache.Set(provider.Id, setupIntent.Id); + break; + } + case PaymentMethodType.Card: + { + options.PaymentMethod = token; + options.InvoiceSettings.DefaultPaymentMethod = token; + break; + } + case PaymentMethodType.PayPal: + { + braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, token); + options.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; + break; + } + } + } + try { return await stripeAdapter.CustomerCreateAsync(options); } - catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.TaxIdInvalid) + catch (StripeException stripeException) when (stripeException.StripeError?.Code == + StripeConstants.ErrorCodes.TaxIdInvalid) { - throw new BadRequestException("Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid."); + await Revert(); + throw new BadRequestException( + "Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid."); + } + catch + { + await Revert(); + throw; + } + + async Task Revert() + { + if (requireProviderPaymentMethodDuringSetup && tokenizedPaymentSource != null) + { + // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault + switch (tokenizedPaymentSource.Type) + { + case PaymentMethodType.BankAccount: + { + var setupIntentId = await setupIntentCache.Get(provider.Id); + await stripeAdapter.SetupIntentCancel(setupIntentId, + new SetupIntentCancelOptions { CancellationReason = "abandoned" }); + await setupIntentCache.Remove(provider.Id); + break; + } + case PaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): + { + await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); + break; + } + } + } } } @@ -580,18 +674,38 @@ public class ProviderBillingService( }); } + var requireProviderPaymentMethodDuringSetup = + featureService.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup); + + var setupIntentId = await setupIntentCache.Get(provider.Id); + + var setupIntent = !string.IsNullOrEmpty(setupIntentId) + ? await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + { + Expand = ["payment_method"] + }) + : null; + + var usePaymentMethod = + requireProviderPaymentMethodDuringSetup && + (!string.IsNullOrEmpty(customer.InvoiceSettings.DefaultPaymentMethodId) || + customer.Metadata.ContainsKey(BraintreeCustomerIdKey) || + setupIntent.IsUnverifiedBankAccount()); + var subscriptionCreateOptions = new SubscriptionCreateOptions { - CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, + CollectionMethod = usePaymentMethod ? + StripeConstants.CollectionMethod.ChargeAutomatically : StripeConstants.CollectionMethod.SendInvoice, Customer = customer.Id, - DaysUntilDue = 30, + DaysUntilDue = usePaymentMethod ? null : 30, Items = subscriptionItemOptionsList, Metadata = new Dictionary { { "providerId", provider.Id.ToString() } }, OffSession = true, - ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations + ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, + TrialPeriodDays = usePaymentMethod ? 14 : null }; if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) @@ -607,7 +721,10 @@ public class ProviderBillingService( { var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); - if (subscription.Status == StripeConstants.SubscriptionStatus.Active) + if (subscription is + { + Status: StripeConstants.SubscriptionStatus.Active or StripeConstants.SubscriptionStatus.Trialing + }) { return subscription; } diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index d2d82f47de..c66acfa8ce 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -1,5 +1,6 @@ using Bit.Commercial.Core.AdminConsole.Services; using Bit.Commercial.Core.Test.AdminConsole.AutoFixture; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -7,6 +8,7 @@ using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Context; @@ -38,7 +40,7 @@ public class ProviderServiceTests public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) { var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default)); + () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null)); Assert.Contains("Invalid owner.", exception.Message); } @@ -50,12 +52,85 @@ public class ProviderServiceTests userService.GetUserByIdAsync(user.Id).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default)); + () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null)); Assert.Contains("Invalid token.", exception.Message); } [Theory, BitAutoData] - public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TaxInfo taxInfo, + public async Task CompleteSetupAsync_InvalidTaxInfo_ThrowsBadRequestException( + User user, + Provider provider, + string key, + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource, + [ProviderUser] ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.ProviderId = provider.Id; + providerUser.UserId = user.Id; + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + + sutProvider.Create(); + + var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + taxInfo.BillingAddressCountry = null; + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource)); + + Assert.Equal("Both address and postal code are required to set up your provider.", exception.Message); + } + + [Theory, BitAutoData] + public async Task CompleteSetupAsync_InvalidTokenizedPaymentSource_ThrowsBadRequestException( + User user, + Provider provider, + string key, + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource, + [ProviderUser] ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.ProviderId = provider.Id; + providerUser.UserId = user.Id; + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + + sutProvider.Create(); + + var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay }; + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource)); + + Assert.Equal("A payment method is required to set up your provider.", exception.Message); + } + + [Theory, BitAutoData] + public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource, [ProviderUser] ProviderUser providerUser, SutProvider sutProvider) { @@ -75,7 +150,7 @@ public class ProviderServiceTests var providerBillingService = sutProvider.GetDependency(); var customer = new Customer { Id = "customer_id" }; - providerBillingService.SetupCustomer(provider, taxInfo).Returns(customer); + providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource).Returns(customer); var subscription = new Subscription { Id = "subscription_id" }; providerBillingService.SetupSubscription(provider).Returns(subscription); @@ -84,7 +159,7 @@ public class ProviderServiceTests var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo); + await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource); await sutProvider.GetDependency().Received().UpsertAsync(Arg.Is( p => diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index 2661a0eff6..1862692087 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -2,14 +2,17 @@ using System.Net; using Bit.Commercial.Core.Billing; using Bit.Commercial.Core.Billing.Models; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; @@ -24,11 +27,17 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; +using Braintree; using CsvHelper; using NSubstitute; +using NSubstitute.ExceptionExtensions; using Stripe; using Xunit; using static Bit.Core.Test.Billing.Utilities; +using Address = Stripe.Address; +using Customer = Stripe.Customer; +using PaymentMethod = Stripe.PaymentMethod; +using Subscription = Stripe.Subscription; namespace Bit.Commercial.Core.Test.Billing; @@ -833,7 +842,7 @@ public class ProviderBillingServiceTests } [Theory, BitAutoData] - public async Task SetupCustomer_Success( + public async Task SetupCustomer_NoPaymentMethod_Success( SutProvider sutProvider, Provider provider, TaxInfo taxInfo) @@ -877,6 +886,301 @@ public class ProviderBillingServiceTests Assert.Equivalent(expected, actual); } + [Theory, BitAutoData] + public async Task SetupCustomer_InvalidRequiredPaymentMethod_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay }; + + await ThrowsBillingExceptionAsync(() => + sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithBankAccount_Error_Reverts( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + stripeAdapter.SetupIntentList(Arg.Is(options => + options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ + new SetupIntent { Id = "setup_intent_id" } + ]); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Throws(); + + sutProvider.GetDependency().Get(provider.Id).Returns("setup_intent_id"); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + + await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_id"); + + await stripeAdapter.Received(1).SetupIntentCancel("setup_intent_id", Arg.Is(options => + options.CancellationReason == "abandoned")); + + await sutProvider.GetDependency().Received(1).Remove(provider.Id); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithPayPal_Error_Reverts( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token) + .Returns("braintree_customer_id"); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.Metadata["btCustomerId"] == "braintree_customer_id" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Throws(); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + + await sutProvider.GetDependency().Customer.Received(1).DeleteAsync("braintree_customer_id"); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithBankAccount_Success( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var expected = new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + stripeAdapter.SetupIntentList(Arg.Is(options => + options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ + new SetupIntent { Id = "setup_intent_id" } + ]); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Returns(expected); + + var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + + Assert.Equivalent(expected, actual); + + await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_id"); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithPayPal_Success( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var expected = new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token) + .Returns("braintree_customer_id"); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.Metadata["btCustomerId"] == "braintree_customer_id" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Returns(expected); + + var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + + Assert.Equivalent(expected, actual); + } + + [Theory, BitAutoData] + public async Task SetupCustomer_WithCard_Success( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var expected = new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.PaymentMethod == tokenizedPaymentSource.Token && + o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Returns(expected); + + var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + + Assert.Equivalent(expected, actual); + } + [Theory, BitAutoData] public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid( SutProvider sutProvider, @@ -1044,7 +1348,7 @@ public class ProviderBillingServiceTests } [Theory, BitAutoData] - public async Task SetupSubscription_Succeeds( + public async Task SetupSubscription_SendInvoice_Succeeds( SutProvider sutProvider, Provider provider) { @@ -1127,6 +1431,303 @@ public class ProviderBillingServiceTests Assert.Equivalent(expected, actual); } + [Theory, BitAutoData] + public async Task SetupSubscription_ChargeAutomatically_HasCard_Succeeds( + SutProvider sutProvider, + Provider provider) + { + provider.Type = ProviderType.Msp; + provider.GatewaySubscriptionId = null; + + var customer = new Customer + { + Id = "customer_id", + InvoiceSettings = new CustomerInvoiceSettings + { + DefaultPaymentMethodId = "pm_123" + }, + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }, + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + } + }; + + foreach (var plan in providerPlans) + { + sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) + .Returns(StaticStore.GetPlan(plan.PlanType)); + } + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == "customer_id") + , Arg.Is(p => p == customer))) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && + sub.Customer == "customer_id" && + sub.DaysUntilDue == null && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + sub.TrialPeriodDays == 14)).Returns(expected); + + var actual = await sutProvider.Sut.SetupSubscription(provider); + + Assert.Equivalent(expected, actual); + } + + [Theory, BitAutoData] + public async Task SetupSubscription_ChargeAutomatically_HasBankAccount_Succeeds( + SutProvider sutProvider, + Provider provider) + { + provider.Type = ProviderType.Msp; + provider.GatewaySubscriptionId = null; + + var customer = new Customer + { + Id = "customer_id", + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary(), + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }, + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + } + }; + + foreach (var plan in providerPlans) + { + sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) + .Returns(StaticStore.GetPlan(plan.PlanType)); + } + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == "customer_id") + , Arg.Is(p => p == customer))) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + const string setupIntentId = "seti_123"; + + sutProvider.GetDependency().Get(provider.Id).Returns(setupIntentId); + + sutProvider.GetDependency().SetupIntentGet(setupIntentId, Arg.Is(options => + options.Expand.Contains("payment_method"))).Returns(new SetupIntent + { + Id = setupIntentId, + Status = "requires_action", + NextAction = new SetupIntentNextAction + { + VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() + }, + PaymentMethod = new PaymentMethod + { + UsBankAccount = new PaymentMethodUsBankAccount() + } + }); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && + sub.Customer == "customer_id" && + sub.DaysUntilDue == null && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + sub.TrialPeriodDays == 14)).Returns(expected); + + var actual = await sutProvider.Sut.SetupSubscription(provider); + + Assert.Equivalent(expected, actual); + } + + [Theory, BitAutoData] + public async Task SetupSubscription_ChargeAutomatically_HasPayPal_Succeeds( + SutProvider sutProvider, + Provider provider) + { + provider.Type = ProviderType.Msp; + provider.GatewaySubscriptionId = null; + + var customer = new Customer + { + Id = "customer_id", + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary + { + ["btCustomerId"] = "braintree_customer_id" + }, + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }, + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + } + }; + + foreach (var plan in providerPlans) + { + sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) + .Returns(StaticStore.GetPlan(plan.PlanType)); + } + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + + sutProvider.GetDependency() + .When(x => x.SetCreateOptions( + Arg.Is(options => + options.Customer == "customer_id") + , Arg.Is(p => p == customer))) + .Do(x => + { + x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }; + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && + sub.Customer == "customer_id" && + sub.DaysUntilDue == null && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + sub.TrialPeriodDays == 14)).Returns(expected); + + var actual = await sutProvider.Sut.SetupSubscription(provider); + + Assert.Equivalent(expected, actual); + } + #endregion #region UpdateSeatMinimums diff --git a/src/Api/AdminConsole/Controllers/ProvidersController.cs b/src/Api/AdminConsole/Controllers/ProvidersController.cs index be119744b3..b6933da0c9 100644 --- a/src/Api/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Api/AdminConsole/Controllers/ProvidersController.cs @@ -84,22 +84,22 @@ public class ProvidersController : Controller var userId = _userService.GetProperUserId(User).Value; - var taxInfo = model.TaxInfo != null - ? new TaxInfo - { - BillingAddressCountry = model.TaxInfo.Country, - BillingAddressPostalCode = model.TaxInfo.PostalCode, - TaxIdNumber = model.TaxInfo.TaxId, - BillingAddressLine1 = model.TaxInfo.Line1, - BillingAddressLine2 = model.TaxInfo.Line2, - BillingAddressCity = model.TaxInfo.City, - BillingAddressState = model.TaxInfo.State - } - : null; + var taxInfo = new TaxInfo + { + BillingAddressCountry = model.TaxInfo.Country, + BillingAddressPostalCode = model.TaxInfo.PostalCode, + TaxIdNumber = model.TaxInfo.TaxId, + BillingAddressLine1 = model.TaxInfo.Line1, + BillingAddressLine2 = model.TaxInfo.Line2, + BillingAddressCity = model.TaxInfo.City, + BillingAddressState = model.TaxInfo.State + }; + + var tokenizedPaymentSource = model.PaymentSource?.ToDomain(); var response = await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key, - taxInfo); + taxInfo, tokenizedPaymentSource); return new ProviderResponseModel(response); } diff --git a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs index 5e10807c69..697077c9b6 100644 --- a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs @@ -1,5 +1,6 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; +using Bit.Api.Billing.Models.Requests; using Bit.Api.Models.Request; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Utilities; @@ -23,7 +24,9 @@ public class ProviderSetupRequestModel public string Token { get; set; } [Required] public string Key { get; set; } + [Required] public ExpandedTaxInfoUpdateRequestModel TaxInfo { get; set; } + public TokenizedPaymentSourceRequestBody PaymentSource { get; set; } public virtual Provider ToProvider(Provider provider) { diff --git a/src/Core/AdminConsole/Services/IProviderService.cs b/src/Core/AdminConsole/Services/IProviderService.cs index 8999b3cb81..e4b6f3aabd 100644 --- a/src/Core/AdminConsole/Services/IProviderService.cs +++ b/src/Core/AdminConsole/Services/IProviderService.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; +using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -7,7 +8,8 @@ namespace Bit.Core.AdminConsole.Services; public interface IProviderService { - Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null); + Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource = null); Task UpdateAsync(Provider provider, bool updateBilling = false); Task> InviteUserAsync(ProviderUserInvite invite); diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs index bd3a757663..94c1096b58 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; +using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -7,7 +8,7 @@ namespace Bit.Core.AdminConsole.Services.NoopImplementations; public class NoopProviderService : IProviderService { - public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null) => throw new NotImplementedException(); + public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) => throw new NotImplementedException(); public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); diff --git a/src/Core/Billing/Services/IProviderBillingService.cs b/src/Core/Billing/Services/IProviderBillingService.cs index 64585f3361..6ed8910dd8 100644 --- a/src/Core/Billing/Services/IProviderBillingService.cs +++ b/src/Core/Billing/Services/IProviderBillingService.cs @@ -79,10 +79,12 @@ public interface IProviderBillingService /// /// The to create a Stripe customer for. /// The to use for calculating the customer's automatic tax. + /// The (ex. Credit Card) to attach to the customer. /// The newly created for the . Task SetupCustomer( Provider provider, - TaxInfo taxInfo); + TaxInfo taxInfo, + TokenizedPaymentSource tokenizedPaymentSource = null); /// /// For use during the provider setup process, this method starts a Stripe for the given . diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index a05b89a94f..0d4c105b2d 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -149,6 +149,7 @@ public static class FeatureFlagKeys public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion"; public const string PM199566_UpdateMSPToChargeAutomatically = "pm-199566-update-msp-to-charge-automatically"; + public const string PM19956_RequireProviderPaymentMethodDuringSetup = "pm-19956-require-provider-payment-method-during-setup"; /* Data Insights and Reporting Team */ public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application";