1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-04 20:50:21 -05:00

[AC-2888] Improve consolidated billing error handling (#4548)

* Fix error handling in provider setup process

This update ensures that when 'enable-consolidated-billing' is on, any exception thrown during the Stripe customer or subscription setup process for the provider will block the remainder of the setup process so the provider does not enter an invalid state

* Refactor the way BillingException is thrown

Made it simpler to just use the exception constructor and also ensured it was added to the exception handling middleware so it could provide a simple response to the client

* Handle all Stripe exceptions in exception handling middleware

* Fixed error response output for billing's provider controllers

* Cleaned up billing owned provider controllers

Changes were made based on feature updates by product and stuff that's no longer needed. No need to expose sensitive endpoints when they're not being used.

* Reafctored get invoices

Removed unnecssarily bloated method from SubscriberService

* Updated error handling for generating the client invoice report

* Moved get provider subscription to controller

This is only used once and the service layer doesn't seem like the correct choice anymore when thinking about error handling with retrieval

* Handled bad request for update tax information

* Split out Stripe configuration from unauthorization

* Run dotnet format

* Addison's feedback
This commit is contained in:
Alex Morask 2024-07-31 09:26:44 -04:00 committed by GitHub
parent 85ddd080cb
commit 398741cec4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 777 additions and 1260 deletions

View File

@ -9,6 +9,7 @@ using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Services;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
@ -45,6 +46,7 @@ public class ProviderService : IProviderService
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
private readonly IDataProtectorTokenFactory<ProviderDeleteTokenable> _providerDeleteTokenDataFactory; private readonly IDataProtectorTokenFactory<ProviderDeleteTokenable> _providerDeleteTokenDataFactory;
private readonly IApplicationCacheService _applicationCacheService; private readonly IApplicationCacheService _applicationCacheService;
private readonly IProviderBillingService _providerBillingService;
public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository,
@ -53,7 +55,7 @@ public class ProviderService : IProviderService
IOrganizationRepository organizationRepository, GlobalSettings globalSettings, IOrganizationRepository organizationRepository, GlobalSettings globalSettings,
ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService, ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService,
IDataProtectorTokenFactory<ProviderDeleteTokenable> providerDeleteTokenDataFactory, IDataProtectorTokenFactory<ProviderDeleteTokenable> providerDeleteTokenDataFactory,
IApplicationCacheService applicationCacheService) IApplicationCacheService applicationCacheService, IProviderBillingService providerBillingService)
{ {
_providerRepository = providerRepository; _providerRepository = providerRepository;
_providerUserRepository = providerUserRepository; _providerUserRepository = providerUserRepository;
@ -71,9 +73,10 @@ public class ProviderService : IProviderService
_featureService = featureService; _featureService = featureService;
_providerDeleteTokenDataFactory = providerDeleteTokenDataFactory; _providerDeleteTokenDataFactory = providerDeleteTokenDataFactory;
_applicationCacheService = applicationCacheService; _applicationCacheService = applicationCacheService;
_providerBillingService = providerBillingService;
} }
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null)
{ {
var owner = await _userService.GetUserByIdAsync(ownerUserId); var owner = await _userService.GetUserByIdAsync(ownerUserId);
if (owner == null) if (owner == null)
@ -98,8 +101,24 @@ public class ProviderService : IProviderService
throw new BadRequestException("Invalid owner."); throw new BadRequestException("Invalid owner.");
} }
provider.Status = ProviderStatusType.Created; if (!_featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
await _providerRepository.UpsertAsync(provider); {
provider.Status = ProviderStatusType.Created;
await _providerRepository.UpsertAsync(provider);
}
else
{
if (taxInfo == null || string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode))
{
throw new BadRequestException("Both address and postal code are required to set up your provider.");
}
var customer = await _providerBillingService.SetupCustomer(provider, taxInfo);
provider.GatewayCustomerId = customer.Id;
var subscription = await _providerBillingService.SetupSubscription(provider);
provider.GatewaySubscriptionId = subscription.Id;
provider.Status = ProviderStatusType.Billable;
await _providerRepository.UpsertAsync(provider);
}
providerUser.Key = key; providerUser.Key = key;
await _providerUserRepository.ReplaceAsync(providerUser); await _providerUserRepository.ReplaceAsync(providerUser);

View File

@ -9,11 +9,11 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities; using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
@ -22,7 +22,6 @@ using Bit.Core.Utilities;
using CsvHelper; using CsvHelper;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Stripe; using Stripe;
using static Bit.Core.Billing.Utilities;
namespace Bit.Commercial.Core.Billing; namespace Bit.Commercial.Core.Billing;
@ -69,67 +68,6 @@ public class ProviderBillingService(
await organizationRepository.ReplaceAsync(organization); await organizationRepository.ReplaceAsync(organization);
} }
public async Task CreateCustomer(
Provider provider,
TaxInfo taxInfo)
{
ArgumentNullException.ThrowIfNull(provider);
ArgumentNullException.ThrowIfNull(taxInfo);
if (string.IsNullOrEmpty(taxInfo.BillingAddressCountry) ||
string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode))
{
logger.LogError("Cannot create customer for provider ({ProviderID}) without both a country and postal code", provider.Id);
throw ContactSupport();
}
var providerDisplayName = provider.DisplayName();
var customerCreateOptions = new CustomerCreateOptions
{
Address = new AddressOptions
{
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode,
Line1 = taxInfo.BillingAddressLine1,
Line2 = taxInfo.BillingAddressLine2,
City = taxInfo.BillingAddressCity,
State = taxInfo.BillingAddressState
},
Description = provider.DisplayBusinessName(),
Email = provider.BillingEmail,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = provider.SubscriberType(),
Value = providerDisplayName.Length <= 30
? providerDisplayName
: providerDisplayName[..30]
}
]
},
Metadata = new Dictionary<string, string>
{
{ "region", globalSettings.BaseServiceUri.CloudRegion }
},
TaxIdData = taxInfo.HasTaxId ?
[
new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber }
]
: null
};
var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
provider.GatewayCustomerId = customer.Id;
await providerRepository.ReplaceAsync(provider);
}
public async Task CreateCustomerForClientOrganization( public async Task CreateCustomerForClientOrganization(
Provider provider, Provider provider,
Organization organization) Organization organization)
@ -204,15 +142,14 @@ public class ProviderBillingService(
public async Task<byte[]> GenerateClientInvoiceReport( public async Task<byte[]> GenerateClientInvoiceReport(
string invoiceId) string invoiceId)
{ {
if (string.IsNullOrEmpty(invoiceId)) ArgumentException.ThrowIfNullOrEmpty(invoiceId);
{
throw new ArgumentNullException(nameof(invoiceId));
}
var invoiceItems = await providerInvoiceItemRepository.GetByInvoiceId(invoiceId); var invoiceItems = await providerInvoiceItemRepository.GetByInvoiceId(invoiceId);
if (invoiceItems.Count == 0) if (invoiceItems.Count == 0)
{ {
logger.LogError("No provider invoice item records were found for invoice ({InvoiceID})", invoiceId);
return null; return null;
} }
@ -245,14 +182,14 @@ public class ProviderBillingService(
"Could not find provider ({ID}) when retrieving assigned seat total", "Could not find provider ({ID}) when retrieving assigned seat total",
providerId); providerId);
throw ContactSupport(); throw new BillingException();
} }
if (provider.Type == ProviderType.Reseller) if (provider.Type == ProviderType.Reseller)
{ {
logger.LogError("Assigned seats cannot be retrieved for reseller-type provider ({ID})", providerId); logger.LogError("Assigned seats cannot be retrieved for reseller-type provider ({ID})", providerId);
throw ContactSupport("Consolidated billing does not support reseller-type providers"); throw new BillingException();
} }
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId);
@ -264,39 +201,6 @@ public class ProviderBillingService(
.Sum(providerOrganization => providerOrganization.Seats ?? 0); .Sum(providerOrganization => providerOrganization.Seats ?? 0);
} }
public async Task<ConsolidatedBillingSubscriptionDTO> GetConsolidatedBillingSubscription(
Provider provider)
{
ArgumentNullException.ThrowIfNull(provider);
var subscription = await subscriberService.GetSubscription(provider, new SubscriptionGetOptions
{
Expand = ["customer", "test_clock"]
});
if (subscription == null)
{
return null;
}
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
var configuredProviderPlans = providerPlans
.Where(providerPlan => providerPlan.IsConfigured())
.Select(ConfiguredProviderPlanDTO.From)
.ToList();
var taxInformation = await subscriberService.GetTaxInformation(provider);
var suspension = await GetSuspensionAsync(stripeAdapter, subscription);
return new ConsolidatedBillingSubscriptionDTO(
configuredProviderPlans,
subscription,
taxInformation,
suspension);
}
public async Task ScaleSeats( public async Task ScaleSeats(
Provider provider, Provider provider,
PlanType planType, PlanType planType,
@ -308,14 +212,14 @@ public class ProviderBillingService(
{ {
logger.LogError("Non-MSP provider ({ProviderID}) cannot scale their seats", provider.Id); logger.LogError("Non-MSP provider ({ProviderID}) cannot scale their seats", provider.Id);
throw ContactSupport(); throw new BillingException();
} }
if (!planType.SupportsConsolidatedBilling()) if (!planType.SupportsConsolidatedBilling())
{ {
logger.LogError("Cannot scale provider ({ProviderID}) seats for plan type {PlanType} as it does not support consolidated billing", provider.Id, planType.ToString()); logger.LogError("Cannot scale provider ({ProviderID}) seats for plan type {PlanType} as it does not support consolidated billing", provider.Id, planType.ToString());
throw ContactSupport(); throw new BillingException();
} }
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
@ -326,7 +230,7 @@ public class ProviderBillingService(
{ {
logger.LogError("Cannot scale provider ({ProviderID}) seats for plan type {PlanType} when their matching provider plan is not configured", provider.Id, planType); logger.LogError("Cannot scale provider ({ProviderID}) seats for plan type {PlanType} when their matching provider plan is not configured", provider.Id, planType);
throw ContactSupport(); throw new BillingException();
} }
var seatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0); var seatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0);
@ -362,7 +266,7 @@ public class ProviderBillingService(
{ {
logger.LogError("Service user for provider ({ProviderID}) cannot scale a provider's seat count over the seat minimum", provider.Id); logger.LogError("Service user for provider ({ProviderID}) cannot scale a provider's seat count over the seat minimum", provider.Id);
throw ContactSupport(); throw new BillingException();
} }
await update( await update(
@ -393,7 +297,64 @@ public class ProviderBillingService(
} }
} }
public async Task StartSubscription( public async Task<Customer> SetupCustomer(
Provider provider,
TaxInfo taxInfo)
{
ArgumentNullException.ThrowIfNull(provider);
ArgumentNullException.ThrowIfNull(taxInfo);
if (string.IsNullOrEmpty(taxInfo.BillingAddressCountry) ||
string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode))
{
logger.LogError("Cannot create customer for provider ({ProviderID}) without both a country and postal code", provider.Id);
throw new BillingException();
}
var providerDisplayName = provider.DisplayName();
var customerCreateOptions = new CustomerCreateOptions
{
Address = new AddressOptions
{
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode,
Line1 = taxInfo.BillingAddressLine1,
Line2 = taxInfo.BillingAddressLine2,
City = taxInfo.BillingAddressCity,
State = taxInfo.BillingAddressState
},
Description = provider.DisplayBusinessName(),
Email = provider.BillingEmail,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = provider.SubscriberType(),
Value = providerDisplayName?.Length <= 30
? providerDisplayName
: providerDisplayName?[..30]
}
]
},
Metadata = new Dictionary<string, string>
{
{ "region", globalSettings.BaseServiceUri.CloudRegion }
},
TaxIdData = taxInfo.HasTaxId ?
[
new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber }
]
: null
};
return await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
}
public async Task<Subscription> SetupSubscription(
Provider provider) Provider provider)
{ {
ArgumentNullException.ThrowIfNull(provider); ArgumentNullException.ThrowIfNull(provider);
@ -406,7 +367,7 @@ public class ProviderBillingService(
{ {
logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", provider.Id); logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", provider.Id);
throw ContactSupport(); throw new BillingException();
} }
var subscriptionItemOptionsList = new List<SubscriptionItemOptions>(); var subscriptionItemOptionsList = new List<SubscriptionItemOptions>();
@ -418,7 +379,7 @@ public class ProviderBillingService(
{ {
logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured Teams plan", provider.Id); logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured Teams plan", provider.Id);
throw ContactSupport(); throw new BillingException();
} }
var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
@ -436,7 +397,7 @@ public class ProviderBillingService(
{ {
logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured Enterprise plan", provider.Id); logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured Enterprise plan", provider.Id);
throw ContactSupport(); throw new BillingException();
} }
var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly);
@ -465,22 +426,27 @@ public class ProviderBillingService(
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations
}; };
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); try
provider.GatewaySubscriptionId = subscription.Id;
if (subscription.Status == StripeConstants.SubscriptionStatus.Incomplete)
{ {
await providerRepository.ReplaceAsync(provider); var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
logger.LogError("Started incomplete provider ({ProviderID}) subscription ({SubscriptionID})", provider.Id, subscription.Id); if (subscription.Status == StripeConstants.SubscriptionStatus.Active)
{
return subscription;
}
throw ContactSupport(); logger.LogError(
"Newly created provider ({ProviderID}) subscription ({SubscriptionID}) has inactive status: {Status}",
provider.Id,
subscription.Id,
subscription.Status);
throw new BillingException();
}
catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid)
{
throw new BadRequestException("Your location wasn't recognized. Please ensure your country and postal code are valid.");
} }
provider.Status = ProviderStatusType.Billable;
await providerRepository.ReplaceAsync(provider);
} }
private Func<int, int, Task> CurrySeatScalingUpdate( private Func<int, int, Task> CurrySeatScalingUpdate(

View File

@ -8,6 +8,7 @@ using Bit.Core.AdminConsole.Models.Business.Provider;
using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Models.Business.Tokenables;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Services;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
@ -81,6 +82,51 @@ public class ProviderServiceTests
.ReplaceAsync(Arg.Is<ProviderUser>(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key)); .ReplaceAsync(Arg.Is<ProviderUser>(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key));
} }
[Theory, BitAutoData]
public async Task CompleteSetupAsync_ConsolidatedBilling_Success(User user, Provider provider, string key, TaxInfo taxInfo,
[ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser providerUser,
SutProvider<ProviderService> sutProvider)
{
providerUser.ProviderId = provider.Id;
providerUser.UserId = user.Id;
var userService = sutProvider.GetDependency<IUserService>();
userService.GetUserByIdAsync(user.Id).Returns(user);
var providerUserRepository = sutProvider.GetDependency<IProviderUserRepository>();
providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser);
var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName");
var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
sutProvider.GetDependency<IDataProtectionProvider>().CreateProtector("ProviderServiceDataProtector")
.Returns(protector);
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
var providerBillingService = sutProvider.GetDependency<IProviderBillingService>();
var customer = new Customer { Id = "customer_id" };
providerBillingService.SetupCustomer(provider, taxInfo).Returns(customer);
var subscription = new Subscription { Id = "subscription_id" };
providerBillingService.SetupSubscription(provider).Returns(subscription);
sutProvider.Create();
var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo);
await sutProvider.GetDependency<IProviderRepository>().Received().UpsertAsync(Arg.Is<Provider>(
p =>
p.GatewayCustomerId == customer.Id &&
p.GatewaySubscriptionId == subscription.Id &&
p.Status == ProviderStatusType.Billable));
await sutProvider.GetDependency<IProviderUserRepository>().Received()
.ReplaceAsync(Arg.Is<ProviderUser>(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key));
}
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdateAsync_ProviderIdIsInvalid_Throws(Provider provider, SutProvider<ProviderService> sutProvider) public async Task UpdateAsync_ProviderIdIsInvalid_Throws(Provider provider, SutProvider<ProviderService> sutProvider)
{ {

View File

@ -11,7 +11,6 @@ using Bit.Core.Billing;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities; using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Context; using Bit.Core.Context;
@ -87,7 +86,7 @@ public class ProviderBillingServiceTests
{ {
organization.PlanType = PlanType.FamiliesAnnually; organization.PlanType = PlanType.FamiliesAnnually;
await ThrowsContactSupportAsync(() => await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats));
} }
@ -105,7 +104,7 @@ public class ProviderBillingServiceTests
new() { Id = Guid.NewGuid(), PlanType = PlanType.TeamsMonthly, ProviderId = provider.Id } new() { Id = Guid.NewGuid(), PlanType = PlanType.TeamsMonthly, ProviderId = provider.Id }
}); });
await ThrowsContactSupportAsync(() => await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats));
} }
@ -247,7 +246,7 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(provider.Id).Returns(false); sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(provider.Id).Returns(false);
await ThrowsContactSupportAsync(() => await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats));
} }
@ -493,105 +492,6 @@ public class ProviderBillingServiceTests
#endregion #endregion
#region CreateCustomer
[Theory, BitAutoData]
public async Task CreateCustomer_NullProvider_ThrowsArgumentNullException(
SutProvider<ProviderBillingService> sutProvider,
TaxInfo taxInfo) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.CreateCustomer(null, taxInfo));
[Theory, BitAutoData]
public async Task CreateCustomer_NullTaxInfo_ThrowsArgumentNullException(
SutProvider<ProviderBillingService> sutProvider,
Provider provider) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.CreateCustomer(provider, null));
[Theory, BitAutoData]
public async Task CreateCustomer_MissingCountry_ContactSupport(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
{
taxInfo.BillingAddressCountry = null;
await ThrowsContactSupportAsync(() => sutProvider.Sut.CreateCustomer(provider, taxInfo));
await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs()
.CustomerGetAsync(Arg.Any<string>(), Arg.Any<CustomerGetOptions>());
}
[Theory, BitAutoData]
public async Task CreateCustomer_MissingPostalCode_ContactSupport(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
{
taxInfo.BillingAddressCountry = null;
await ThrowsContactSupportAsync(() => sutProvider.Sut.CreateCustomer(provider, taxInfo));
await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs()
.CustomerGetAsync(Arg.Any<string>(), Arg.Any<CustomerGetOptions>());
}
[Theory, BitAutoData]
public async Task CreateCustomer_Success(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
{
provider.Name = "MSP";
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber))
.Returns(new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
});
await sutProvider.Sut.CreateCustomer(provider, taxInfo);
await stripeAdapter.Received(1).CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber));
await sutProvider.GetDependency<IProviderRepository>()
.ReplaceAsync(Arg.Is<Provider>(p => p.GatewayCustomerId == "customer_id"));
}
#endregion
#region CreateCustomerForClientOrganization #region CreateCustomerForClientOrganization
[Theory, BitAutoData] [Theory, BitAutoData]
@ -777,7 +677,7 @@ public class ProviderBillingServiceTests
public async Task GetAssignedSeatTotalForPlanOrThrow_NullProvider_ContactSupport( public async Task GetAssignedSeatTotalForPlanOrThrow_NullProvider_ContactSupport(
Guid providerId, Guid providerId,
SutProvider<ProviderBillingService> sutProvider) SutProvider<ProviderBillingService> sutProvider)
=> await ThrowsContactSupportAsync(() => => await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly)); sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly));
[Theory, BitAutoData] [Theory, BitAutoData]
@ -790,9 +690,8 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId).Returns(provider); sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId).Returns(provider);
await ThrowsContactSupportAsync( await ThrowsBillingExceptionAsync(
() => sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly), () => sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly));
internalMessage: "Consolidated billing does not support reseller-type providers");
} }
[Theory, BitAutoData] [Theory, BitAutoData]
@ -836,197 +735,100 @@ public class ProviderBillingServiceTests
#endregion #endregion
#region GetConsolidatedBillingSubscription #region SetupCustomer
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetConsolidatedBillingSubscription_NullProvider_ThrowsArgumentNullException( public async Task SetupCustomer_NullProvider_ThrowsArgumentNullException(
SutProvider<ProviderBillingService> sutProvider) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.GetConsolidatedBillingSubscription(null));
[Theory, BitAutoData]
public async Task GetConsolidatedBillingSubscription_NullSubscription_ReturnsNull(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider) TaxInfo taxInfo) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.SetupCustomer(null, taxInfo));
[Theory, BitAutoData]
public async Task SetupCustomer_NullTaxInfo_ThrowsArgumentNullException(
SutProvider<ProviderBillingService> sutProvider,
Provider provider) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.SetupCustomer(provider, null));
[Theory, BitAutoData]
public async Task SetupCustomer_MissingCountry_ContactSupport(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
{ {
var consolidatedBillingSubscription = await sutProvider.Sut.GetConsolidatedBillingSubscription(provider); taxInfo.BillingAddressCountry = null;
Assert.Null(consolidatedBillingSubscription); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo));
await sutProvider.GetDependency<ISubscriberService>().Received(1).GetSubscription( await sutProvider.GetDependency<IStripeAdapter>()
provider, .DidNotReceiveWithAnyArgs()
Arg.Is<SubscriptionGetOptions>( .CustomerGetAsync(Arg.Any<string>(), Arg.Any<CustomerGetOptions>());
options => options.Expand.Count == 2 && options.Expand.First() == "customer" && options.Expand.Last() == "test_clock"));
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetConsolidatedBillingSubscription_Active_NoSuspension_Success( public async Task SetupCustomer_MissingPostalCode_ContactSupport(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider) Provider provider,
TaxInfo taxInfo)
{ {
var subscriberService = sutProvider.GetDependency<ISubscriberService>(); taxInfo.BillingAddressCountry = null;
var subscription = new Subscription await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo));
{
Status = "active"
};
subscriberService.GetSubscription(provider, Arg.Is<SubscriptionGetOptions>( await sutProvider.GetDependency<IStripeAdapter>()
options => options.Expand.Count == 2 && options.Expand.First() == "customer" && options.Expand.Last() == "test_clock")).Returns(subscription); .DidNotReceiveWithAnyArgs()
.CustomerGetAsync(Arg.Any<string>(), Arg.Any<CustomerGetOptions>());
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
var enterprisePlan = new ProviderPlan
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.EnterpriseMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
AllocatedSeats = 0
};
var teamsPlan = new ProviderPlan
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.TeamsMonthly,
SeatMinimum = 50,
PurchasedSeats = 10,
AllocatedSeats = 60
};
var providerPlans = new List<ProviderPlan> { enterprisePlan, teamsPlan, };
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var taxInformation =
new TaxInformationDTO("US", "12345", "123456789", "123 Example St.", null, "Example Town", "NY");
subscriberService.GetTaxInformation(provider).Returns(taxInformation);
var (gotProviderPlans, gotSubscription, gotTaxInformation, gotSuspension) = await sutProvider.Sut.GetConsolidatedBillingSubscription(provider);
Assert.Equal(2, gotProviderPlans.Count);
var configuredEnterprisePlan =
gotProviderPlans.FirstOrDefault(configuredPlan =>
configuredPlan.PlanType == PlanType.EnterpriseMonthly);
var configuredTeamsPlan =
gotProviderPlans.FirstOrDefault(configuredPlan =>
configuredPlan.PlanType == PlanType.TeamsMonthly);
Compare(enterprisePlan, configuredEnterprisePlan);
Compare(teamsPlan, configuredTeamsPlan);
Assert.Equivalent(subscription, gotSubscription);
Assert.Equivalent(taxInformation, gotTaxInformation);
Assert.Null(gotSuspension);
return;
void Compare(ProviderPlan providerPlan, ConfiguredProviderPlanDTO configuredProviderPlan)
{
Assert.NotNull(configuredProviderPlan);
Assert.Equal(providerPlan.Id, configuredProviderPlan.Id);
Assert.Equal(providerPlan.ProviderId, configuredProviderPlan.ProviderId);
Assert.Equal(providerPlan.SeatMinimum!.Value, configuredProviderPlan.SeatMinimum);
Assert.Equal(providerPlan.PurchasedSeats!.Value, configuredProviderPlan.PurchasedSeats);
Assert.Equal(providerPlan.AllocatedSeats!.Value, configuredProviderPlan.AssignedSeats);
}
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetConsolidatedBillingSubscription_PastDue_HasSuspension_Success( public async Task SetupCustomer_Success(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider) Provider provider,
TaxInfo taxInfo)
{ {
var subscriberService = sutProvider.GetDependency<ISubscriberService>(); provider.Name = "MSP";
var subscription = new Subscription taxInfo.BillingAddressCountry = "AD";
{
Id = "subscription_id",
Status = "past_due",
CollectionMethod = "send_invoice"
};
subscriberService.GetSubscription(provider, Arg.Is<SubscriptionGetOptions>(
options => options.Expand.Count == 2 && options.Expand.First() == "customer" && options.Expand.Last() == "test_clock")).Returns(subscription);
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
var enterprisePlan = new ProviderPlan
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.EnterpriseMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
AllocatedSeats = 0
};
var teamsPlan = new ProviderPlan
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.TeamsMonthly,
SeatMinimum = 50,
PurchasedSeats = 10,
AllocatedSeats = 60
};
var providerPlans = new List<ProviderPlan> { enterprisePlan, teamsPlan, };
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var taxInformation =
new TaxInformationDTO("US", "12345", "123456789", "123 Example St.", null, "Example Town", "NY");
subscriberService.GetTaxInformation(provider).Returns(taxInformation);
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var openInvoice = new Invoice var expected = new Customer
{ {
Id = "invoice_id", Id = "customer_id",
Status = "open", Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
DueDate = new DateTime(2024, 6, 1),
Created = new DateTime(2024, 5, 1),
PeriodEnd = new DateTime(2024, 6, 1)
}; };
stripeAdapter.InvoiceSearchAsync(Arg.Is<InvoiceSearchOptions>(options => stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
options.Query == $"subscription:'{subscription.Id}' status:'open'")) o.Address.Country == taxInfo.BillingAddressCountry &&
.Returns([openInvoice]); o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Email == provider.BillingEmail &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber))
.Returns(expected);
var (gotProviderPlans, gotSubscription, gotTaxInformation, gotSuspension) = await sutProvider.Sut.GetConsolidatedBillingSubscription(provider); var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo);
Assert.Equal(2, gotProviderPlans.Count); Assert.Equivalent(expected, actual);
Assert.Equivalent(subscription, gotSubscription);
Assert.Equivalent(taxInformation, gotTaxInformation);
Assert.NotNull(gotSuspension);
Assert.Equal(openInvoice.DueDate.Value.AddDays(30), gotSuspension.SuspensionDate);
Assert.Equal(openInvoice.PeriodEnd, gotSuspension.UnpaidPeriodEndDate);
Assert.Equal(30, gotSuspension.GracePeriod);
} }
#endregion #endregion
#region StartSubscription #region SetupSubscription
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task StartSubscription_NullProvider_ThrowsArgumentNullException( public async Task SetupSubscription_NullProvider_ThrowsArgumentNullException(
SutProvider<ProviderBillingService> sutProvider) => SutProvider<ProviderBillingService> sutProvider) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.StartSubscription(null)); await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.SetupSubscription(null));
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task StartSubscription_NoProviderPlans_ContactSupport( public async Task SetupSubscription_NoProviderPlans_ContactSupport(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider) Provider provider)
{ {
@ -1041,7 +843,7 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id) sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id)
.Returns(new List<ProviderPlan>()); .Returns(new List<ProviderPlan>());
await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider));
await sutProvider.GetDependency<IStripeAdapter>() await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs() .DidNotReceiveWithAnyArgs()
@ -1049,7 +851,7 @@ public class ProviderBillingServiceTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task StartSubscription_NoProviderTeamsPlan_ContactSupport( public async Task SetupSubscription_NoProviderTeamsPlan_ContactSupport(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider) Provider provider)
{ {
@ -1066,7 +868,7 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id) sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id)
.Returns(providerPlans); .Returns(providerPlans);
await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider));
await sutProvider.GetDependency<IStripeAdapter>() await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs() .DidNotReceiveWithAnyArgs()
@ -1074,7 +876,7 @@ public class ProviderBillingServiceTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task StartSubscription_NoProviderEnterprisePlan_ContactSupport( public async Task SetupSubscription_NoProviderEnterprisePlan_ContactSupport(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider) Provider provider)
{ {
@ -1091,7 +893,7 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id) sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id)
.Returns(providerPlans); .Returns(providerPlans);
await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider));
await sutProvider.GetDependency<IStripeAdapter>() await sutProvider.GetDependency<IStripeAdapter>()
.DidNotReceiveWithAnyArgs() .DidNotReceiveWithAnyArgs()
@ -1099,7 +901,7 @@ public class ProviderBillingServiceTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task StartSubscription_SubscriptionIncomplete_ThrowsBillingException( public async Task SetupSubscription_SubscriptionIncomplete_ThrowsBillingException(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider) Provider provider)
{ {
@ -1140,14 +942,11 @@ public class ProviderBillingServiceTests
.Returns( .Returns(
new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Incomplete }); new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Incomplete });
await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider));
await sutProvider.GetDependency<IProviderRepository>().Received(1)
.ReplaceAsync(Arg.Is<Provider>(p => p.GatewaySubscriptionId == "subscription_id"));
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task StartSubscription_Succeeds( public async Task SetupSubscription_Succeeds(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider) Provider provider)
{ {
@ -1187,6 +986,8 @@ public class ProviderBillingServiceTests
var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly);
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>( sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub => sub =>
sub.AutomaticTax.Enabled == true && sub.AutomaticTax.Enabled == true &&
@ -1200,16 +1001,11 @@ public class ProviderBillingServiceTests
sub.Items.ElementAt(1).Quantity == 100 && sub.Items.ElementAt(1).Quantity == 100 &&
sub.Metadata["providerId"] == provider.Id.ToString() && sub.Metadata["providerId"] == provider.Id.ToString() &&
sub.OffSession == true && sub.OffSession == true &&
sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)).Returns(new Subscription sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)).Returns(expected);
{
Id = "subscription_id",
Status = StripeConstants.SubscriptionStatus.Active
});
await sutProvider.Sut.StartSubscription(provider); var actual = await sutProvider.Sut.SetupSubscription(provider);
await sutProvider.GetDependency<IProviderRepository>().Received(1) Assert.Equivalent(expected, actual);
.ReplaceAsync(Arg.Is<Provider>(p => p.GatewaySubscriptionId == "subscription_id"));
} }
#endregion #endregion

View File

@ -1,9 +1,7 @@
using Bit.Api.AdminConsole.Models.Request.Providers; using Bit.Api.AdminConsole.Models.Request.Providers;
using Bit.Api.AdminConsole.Models.Response.Providers; using Bit.Api.AdminConsole.Models.Response.Providers;
using Bit.Core;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Services;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
@ -23,23 +21,15 @@ public class ProvidersController : Controller
private readonly IProviderService _providerService; private readonly IProviderService _providerService;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings; private readonly GlobalSettings _globalSettings;
private readonly IFeatureService _featureService;
private readonly ILogger<ProvidersController> _logger;
private readonly IProviderBillingService _providerBillingService;
public ProvidersController(IUserService userService, IProviderRepository providerRepository, public ProvidersController(IUserService userService, IProviderRepository providerRepository,
IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings, IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings)
IFeatureService featureService, ILogger<ProvidersController> logger,
IProviderBillingService providerBillingService)
{ {
_userService = userService; _userService = userService;
_providerRepository = providerRepository; _providerRepository = providerRepository;
_providerService = providerService; _providerService = providerService;
_currentContext = currentContext; _currentContext = currentContext;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_featureService = featureService;
_logger = logger;
_providerBillingService = providerBillingService;
} }
[HttpGet("{id:guid}")] [HttpGet("{id:guid}")]
@ -94,12 +84,8 @@ public class ProvidersController : Controller
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var response = var taxInfo = model.TaxInfo != null
await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key); ? new TaxInfo
if (_featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
{
var taxInfo = new TaxInfo
{ {
BillingAddressCountry = model.TaxInfo.Country, BillingAddressCountry = model.TaxInfo.Country,
BillingAddressPostalCode = model.TaxInfo.PostalCode, BillingAddressPostalCode = model.TaxInfo.PostalCode,
@ -108,20 +94,12 @@ public class ProvidersController : Controller
BillingAddressLine2 = model.TaxInfo.Line2, BillingAddressLine2 = model.TaxInfo.Line2,
BillingAddressCity = model.TaxInfo.City, BillingAddressCity = model.TaxInfo.City,
BillingAddressState = model.TaxInfo.State BillingAddressState = model.TaxInfo.State
};
try
{
await _providerBillingService.CreateCustomer(provider, taxInfo);
await _providerBillingService.StartSubscription(provider);
} }
catch : null;
{
// We don't want to trap the user on the setup page, so we'll let this go through but the provider will be in an un-billable state. var response =
_logger.LogError("Failed to create subscription for provider with ID {ID} during setup", provider.Id); await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key,
} taxInfo);
}
return new ProviderResponseModel(response); return new ProviderResponseModel(response);
} }

View File

@ -3,7 +3,9 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Models.Api;
using Bit.Core.Services; using Bit.Core.Services;
using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
namespace Bit.Api.Billing.Controllers; namespace Bit.Api.Billing.Controllers;
@ -11,8 +13,25 @@ namespace Bit.Api.Billing.Controllers;
public abstract class BaseProviderController( public abstract class BaseProviderController(
ICurrentContext currentContext, ICurrentContext currentContext,
IFeatureService featureService, IFeatureService featureService,
IProviderRepository providerRepository) : Controller ILogger<BaseProviderController> logger,
IProviderRepository providerRepository,
IUserService userService) : Controller
{ {
protected readonly IUserService UserService = userService;
protected static NotFound<ErrorResponseModel> NotFoundResponse() =>
TypedResults.NotFound(new ErrorResponseModel("Resource not found."));
protected static JsonHttpResult<ErrorResponseModel> ServerErrorResponse(string errorMessage) =>
TypedResults.Json(
new ErrorResponseModel(errorMessage),
statusCode: StatusCodes.Status500InternalServerError);
protected static JsonHttpResult<ErrorResponseModel> UnauthorizedResponse() =>
TypedResults.Json(
new ErrorResponseModel("Unauthorized."),
statusCode: StatusCodes.Status401Unauthorized);
protected Task<(Provider, IResult)> TryGetBillableProviderForAdminOperation( protected Task<(Provider, IResult)> TryGetBillableProviderForAdminOperation(
Guid providerId) => TryGetBillableProviderAsync(providerId, currentContext.ProviderProviderAdmin); Guid providerId) => TryGetBillableProviderAsync(providerId, currentContext.ProviderProviderAdmin);
@ -25,26 +44,53 @@ public abstract class BaseProviderController(
{ {
if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
{ {
return (null, TypedResults.NotFound()); logger.LogError(
"Cannot run Consolidated Billing operation for provider ({ProviderID}) while feature flag is disabled",
providerId);
return (null, NotFoundResponse());
} }
var provider = await providerRepository.GetByIdAsync(providerId); var provider = await providerRepository.GetByIdAsync(providerId);
if (provider == null) if (provider == null)
{ {
return (null, TypedResults.NotFound()); logger.LogError(
"Cannot find provider ({ProviderID}) for Consolidated Billing operation",
providerId);
return (null, NotFoundResponse());
} }
if (!checkAuthorization(providerId)) if (!checkAuthorization(providerId))
{ {
return (null, TypedResults.Unauthorized()); var user = await UserService.GetUserByPrincipalAsync(User);
logger.LogError(
"User ({UserID}) is not authorized to perform Consolidated Billing operation for provider ({ProviderID})",
user?.Id, providerId);
return (null, UnauthorizedResponse());
} }
if (!provider.IsBillable()) if (!provider.IsBillable())
{ {
return (null, TypedResults.Unauthorized()); logger.LogError(
"Cannot run Consolidated Billing operation for provider ({ProviderID}) that is not billable",
providerId);
return (null, UnauthorizedResponse());
} }
return (provider, null); if (provider.IsStripeEnabled())
{
return (provider, null);
}
logger.LogError(
"Cannot run Consolidated Billing operation for provider ({ProviderID}) that is missing Stripe configuration",
providerId);
return (null, ServerErrorResponse("Something went wrong with your request. Please contact support."));
} }
} }

View File

@ -1,15 +1,19 @@
using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses; using Bit.Api.Billing.Models.Responses;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Models.Api;
using Bit.Core.Models.BitStripe;
using Bit.Core.Services; using Bit.Core.Services;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Stripe; using Stripe;
using static Bit.Core.Billing.Utilities;
namespace Bit.Api.Billing.Controllers; namespace Bit.Api.Billing.Controllers;
[Route("providers/{providerId:guid}/billing")] [Route("providers/{providerId:guid}/billing")]
@ -17,10 +21,13 @@ namespace Bit.Api.Billing.Controllers;
public class ProviderBillingController( public class ProviderBillingController(
ICurrentContext currentContext, ICurrentContext currentContext,
IFeatureService featureService, IFeatureService featureService,
ILogger<BaseProviderController> logger,
IProviderBillingService providerBillingService, IProviderBillingService providerBillingService,
IProviderPlanRepository providerPlanRepository,
IProviderRepository providerRepository, IProviderRepository providerRepository,
ISubscriberService subscriberService,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService) : BaseProviderController(currentContext, featureService, providerRepository) IUserService userService) : BaseProviderController(currentContext, featureService, logger, providerRepository, userService)
{ {
[HttpGet("invoices")] [HttpGet("invoices")]
public async Task<IResult> GetInvoicesAsync([FromRoute] Guid providerId) public async Task<IResult> GetInvoicesAsync([FromRoute] Guid providerId)
@ -32,7 +39,10 @@ public class ProviderBillingController(
return result; return result;
} }
var invoices = await subscriberService.GetInvoices(provider); var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions
{
Customer = provider.GatewayCustomerId
});
var response = InvoicesResponse.From(invoices); var response = InvoicesResponse.From(invoices);
@ -53,7 +63,7 @@ public class ProviderBillingController(
if (reportContent == null) if (reportContent == null)
{ {
return TypedResults.NotFound(); return ServerErrorResponse("We had a problem generating your invoice CSV. Please contact support.");
} }
return TypedResults.File( return TypedResults.File(
@ -61,95 +71,6 @@ public class ProviderBillingController(
"text/csv"); "text/csv");
} }
[HttpGet("payment-information")]
public async Task<IResult> GetPaymentInformationAsync([FromRoute] Guid providerId)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var paymentInformation = await subscriberService.GetPaymentInformation(provider);
if (paymentInformation == null)
{
return TypedResults.NotFound();
}
var response = PaymentInformationResponse.From(paymentInformation);
return TypedResults.Ok(response);
}
[HttpGet("payment-method")]
public async Task<IResult> GetPaymentMethodAsync([FromRoute] Guid providerId)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var maskedPaymentMethod = await subscriberService.GetPaymentMethod(provider);
if (maskedPaymentMethod == null)
{
return TypedResults.NotFound();
}
var response = MaskedPaymentMethodResponse.From(maskedPaymentMethod);
return TypedResults.Ok(response);
}
[HttpPut("payment-method")]
public async Task<IResult> UpdatePaymentMethodAsync(
[FromRoute] Guid providerId,
[FromBody] TokenizedPaymentMethodRequestBody requestBody)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var tokenizedPaymentMethod = new TokenizedPaymentMethodDTO(
requestBody.Type,
requestBody.Token);
await subscriberService.UpdatePaymentMethod(provider, tokenizedPaymentMethod);
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
new SubscriptionUpdateOptions
{
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically
});
return TypedResults.Ok();
}
[HttpPost]
[Route("payment-method/verify-bank-account")]
public async Task<IResult> VerifyBankAccountAsync(
[FromRoute] Guid providerId,
[FromBody] VerifyBankAccountRequestBody requestBody)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
await subscriberService.VerifyBankAccount(provider, (requestBody.Amount1, requestBody.Amount2));
return TypedResults.Ok();
}
[HttpGet("subscription")] [HttpGet("subscription")]
public async Task<IResult> GetSubscriptionAsync([FromRoute] Guid providerId) public async Task<IResult> GetSubscriptionAsync([FromRoute] Guid providerId)
{ {
@ -160,36 +81,20 @@ public class ProviderBillingController(
return result; return result;
} }
var consolidatedBillingSubscription = await providerBillingService.GetConsolidatedBillingSubscription(provider); var subscription = await stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId,
new SubscriptionGetOptions { Expand = ["customer.tax_ids", "test_clock"] });
if (consolidatedBillingSubscription == null) var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
{
return TypedResults.NotFound();
}
var response = ConsolidatedBillingSubscriptionResponse.From(consolidatedBillingSubscription); var taxInformation = GetTaxInformation(subscription.Customer);
return TypedResults.Ok(response); var subscriptionSuspension = await GetSubscriptionSuspensionAsync(stripeAdapter, subscription);
}
[HttpGet("tax-information")] var response = ProviderSubscriptionResponse.From(
public async Task<IResult> GetTaxInformationAsync([FromRoute] Guid providerId) subscription,
{ providerPlans,
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); taxInformation,
subscriptionSuspension);
if (provider == null)
{
return result;
}
var taxInformation = await subscriberService.GetTaxInformation(provider);
if (taxInformation == null)
{
return TypedResults.NotFound();
}
var response = TaxInformationResponse.From(taxInformation);
return TypedResults.Ok(response); return TypedResults.Ok(response);
} }
@ -206,7 +111,13 @@ public class ProviderBillingController(
return result; return result;
} }
var taxInformation = new TaxInformationDTO( if (requestBody is not { Country: not null, PostalCode: not null })
{
return TypedResults.BadRequest(
new ErrorResponseModel("Country and postal code are required to update your tax information."));
}
var taxInformation = new TaxInformation(
requestBody.Country, requestBody.Country,
requestBody.PostalCode, requestBody.PostalCode,
requestBody.TaxId, requestBody.TaxId,

View File

@ -15,13 +15,13 @@ namespace Bit.Api.Billing.Controllers;
public class ProviderClientsController( public class ProviderClientsController(
ICurrentContext currentContext, ICurrentContext currentContext,
IFeatureService featureService, IFeatureService featureService,
ILogger<ProviderClientsController> logger, ILogger<BaseProviderController> logger,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
IProviderBillingService providerBillingService, IProviderBillingService providerBillingService,
IProviderOrganizationRepository providerOrganizationRepository, IProviderOrganizationRepository providerOrganizationRepository,
IProviderRepository providerRepository, IProviderRepository providerRepository,
IProviderService providerService, IProviderService providerService,
IUserService userService) : BaseProviderController(currentContext, featureService, providerRepository) IUserService userService) : BaseProviderController(currentContext, featureService, logger, providerRepository, userService)
{ {
[HttpPost] [HttpPost]
public async Task<IResult> CreateAsync( public async Task<IResult> CreateAsync(
@ -35,11 +35,11 @@ public class ProviderClientsController(
return result; return result;
} }
var user = await userService.GetUserByPrincipalAsync(User); var user = await UserService.GetUserByPrincipalAsync(User);
if (user == null) if (user == null)
{ {
return TypedResults.Unauthorized(); return UnauthorizedResponse();
} }
var organizationSignup = new OrganizationSignup var organizationSignup = new OrganizationSignup
@ -63,13 +63,6 @@ public class ProviderClientsController(
var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId); var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId);
if (clientOrganization == null)
{
logger.LogError("Newly created client organization ({ID}) could not be found", providerOrganization.OrganizationId);
return TypedResults.Problem();
}
await providerBillingService.ScaleSeats( await providerBillingService.ScaleSeats(
provider, provider,
requestBody.PlanType, requestBody.PlanType,
@ -103,18 +96,11 @@ public class ProviderClientsController(
if (providerOrganization == null) if (providerOrganization == null)
{ {
return TypedResults.NotFound(); return NotFoundResponse();
} }
var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId); var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId);
if (clientOrganization == null)
{
logger.LogError("The client organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id);
return TypedResults.Problem();
}
if (clientOrganization.Seats != requestBody.AssignedSeats) if (clientOrganization.Seats != requestBody.AssignedSeats)
{ {
await providerBillingService.AssignSeatsToClientOrganization( await providerBillingService.AssignSeatsToClientOrganization(

View File

@ -3,16 +3,16 @@
namespace Bit.Api.Billing.Models.Responses; namespace Bit.Api.Billing.Models.Responses;
public record InvoicesResponse( public record InvoicesResponse(
List<InvoiceDTO> Invoices) List<InvoiceResponse> Invoices)
{ {
public static InvoicesResponse From(IEnumerable<Invoice> invoices) => new( public static InvoicesResponse From(IEnumerable<Invoice> invoices) => new(
invoices invoices
.Where(i => i.Status is "open" or "paid" or "uncollectible") .Where(i => i.Status is "open" or "paid" or "uncollectible")
.OrderByDescending(i => i.Created) .OrderByDescending(i => i.Created)
.Select(InvoiceDTO.From).ToList()); .Select(InvoiceResponse.From).ToList());
} }
public record InvoiceDTO( public record InvoiceResponse(
string Id, string Id,
DateTime Date, DateTime Date,
string Number, string Number,
@ -21,7 +21,7 @@ public record InvoiceDTO(
DateTime? DueDate, DateTime? DueDate,
string Url) string Url)
{ {
public static InvoiceDTO From(Invoice invoice) => new( public static InvoiceResponse From(Invoice invoice) => new(
invoice.Id, invoice.Id,
invoice.Created, invoice.Created,
invoice.Number, invoice.Number,

View File

@ -5,7 +5,7 @@ namespace Bit.Api.Billing.Models.Responses;
public record PaymentInformationResponse( public record PaymentInformationResponse(
long AccountCredit, long AccountCredit,
MaskedPaymentMethodDTO PaymentMethod, MaskedPaymentMethodDTO PaymentMethod,
TaxInformationDTO TaxInformation) TaxInformation TaxInformation)
{ {
public static PaymentInformationResponse From(PaymentInformationDTO paymentInformation) => public static PaymentInformationResponse From(PaymentInformationDTO paymentInformation) =>
new( new(

View File

@ -1,43 +1,48 @@
using Bit.Core.Billing.Models; using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Models;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Stripe;
namespace Bit.Api.Billing.Models.Responses; namespace Bit.Api.Billing.Models.Responses;
public record ConsolidatedBillingSubscriptionResponse( public record ProviderSubscriptionResponse(
string Status, string Status,
DateTime CurrentPeriodEndDate, DateTime CurrentPeriodEndDate,
decimal? DiscountPercentage, decimal? DiscountPercentage,
string CollectionMethod, string CollectionMethod,
IEnumerable<ProviderPlanResponse> Plans, IEnumerable<ProviderPlanResponse> Plans,
long AccountCredit, long AccountCredit,
TaxInformationDTO TaxInformation, TaxInformation TaxInformation,
DateTime? CancelAt, DateTime? CancelAt,
SubscriptionSuspensionDTO Suspension) SubscriptionSuspension Suspension)
{ {
private const string _annualCadence = "Annual"; private const string _annualCadence = "Annual";
private const string _monthlyCadence = "Monthly"; private const string _monthlyCadence = "Monthly";
public static ConsolidatedBillingSubscriptionResponse From( public static ProviderSubscriptionResponse From(
ConsolidatedBillingSubscriptionDTO consolidatedBillingSubscription) Subscription subscription,
ICollection<ProviderPlan> providerPlans,
TaxInformation taxInformation,
SubscriptionSuspension subscriptionSuspension)
{ {
var (providerPlans, subscription, taxInformation, suspension) = consolidatedBillingSubscription;
var providerPlanResponses = providerPlans var providerPlanResponses = providerPlans
.Select(providerPlan => .Where(providerPlan => providerPlan.IsConfigured())
.Select(ConfiguredProviderPlan.From)
.Select(configuredProviderPlan =>
{ {
var plan = StaticStore.GetPlan(providerPlan.PlanType); var plan = StaticStore.GetPlan(configuredProviderPlan.PlanType);
var cost = (providerPlan.SeatMinimum + providerPlan.PurchasedSeats) * plan.PasswordManager.ProviderPortalSeatPrice; var cost = (configuredProviderPlan.SeatMinimum + configuredProviderPlan.PurchasedSeats) * plan.PasswordManager.ProviderPortalSeatPrice;
var cadence = plan.IsAnnual ? _annualCadence : _monthlyCadence; var cadence = plan.IsAnnual ? _annualCadence : _monthlyCadence;
return new ProviderPlanResponse( return new ProviderPlanResponse(
plan.Name, plan.Name,
providerPlan.SeatMinimum, configuredProviderPlan.SeatMinimum,
providerPlan.PurchasedSeats, configuredProviderPlan.PurchasedSeats,
providerPlan.AssignedSeats, configuredProviderPlan.AssignedSeats,
cost, cost,
cadence); cadence);
}); });
return new ConsolidatedBillingSubscriptionResponse( return new ProviderSubscriptionResponse(
subscription.Status, subscription.Status,
subscription.CurrentPeriodEnd, subscription.CurrentPeriodEnd,
subscription.Customer?.Discount?.Coupon?.PercentOff, subscription.Customer?.Discount?.Coupon?.PercentOff,
@ -46,7 +51,7 @@ public record ConsolidatedBillingSubscriptionResponse(
subscription.Customer?.Balance ?? 0, subscription.Customer?.Balance ?? 0,
taxInformation, taxInformation,
subscription.CancelAt, subscription.CancelAt,
suspension); subscriptionSuspension);
} }
} }

View File

@ -11,7 +11,7 @@ public record TaxInformationResponse(
string City, string City,
string State) string State)
{ {
public static TaxInformationResponse From(TaxInformationDTO taxInformation) public static TaxInformationResponse From(TaxInformation taxInformation)
=> new( => new(
taxInformation.Country, taxInformation.Country,
taxInformation.PostalCode, taxInformation.PostalCode,

View File

@ -1,4 +1,6 @@
using Bit.Api.Models.Public.Response; using System.Text;
using Bit.Api.Models.Public.Response;
using Bit.Core.Billing;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.AspNetCore.Mvc.Filters;
@ -49,18 +51,18 @@ public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute
errorMessage = badRequestException.Message; errorMessage = badRequestException.Message;
} }
} }
else if (exception is StripeException stripeException && stripeException?.StripeError?.Type == "card_error") else if (exception is StripeException { StripeError.Type: "card_error" } stripeCardErrorException)
{ {
context.HttpContext.Response.StatusCode = 400; context.HttpContext.Response.StatusCode = 400;
if (_publicApi) if (_publicApi)
{ {
publicErrorModel = new ErrorResponseModel(stripeException.StripeError.Param, publicErrorModel = new ErrorResponseModel(stripeCardErrorException.StripeError.Param,
stripeException.Message); stripeCardErrorException.Message);
} }
else else
{ {
internalErrorModel = new InternalApi.ErrorResponseModel(stripeException.StripeError.Param, internalErrorModel = new InternalApi.ErrorResponseModel(stripeCardErrorException.StripeError.Param,
stripeException.Message); stripeCardErrorException.Message);
} }
} }
else if (exception is GatewayException) else if (exception is GatewayException)
@ -68,6 +70,40 @@ public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute
errorMessage = exception.Message; errorMessage = exception.Message;
context.HttpContext.Response.StatusCode = 400; context.HttpContext.Response.StatusCode = 400;
} }
else if (exception is BillingException billingException)
{
errorMessage = billingException.Response;
context.HttpContext.Response.StatusCode = StatusCodes.Status500InternalServerError;
}
else if (exception is StripeException stripeException)
{
var logger = context.HttpContext.RequestServices.GetRequiredService<ILogger<ExceptionHandlerFilterAttribute>>();
var error = stripeException.Message;
if (stripeException.StripeError != null)
{
var stringBuilder = new StringBuilder();
if (!string.IsNullOrEmpty(stripeException.StripeError.Code))
{
stringBuilder.Append($"{stripeException.StripeError.Code} | ");
}
stringBuilder.Append(stripeException.StripeError.Message);
if (!string.IsNullOrEmpty(stripeException.StripeError.DocUrl))
{
stringBuilder.Append($" > {stripeException.StripeError.DocUrl}");
}
error = stringBuilder.ToString();
}
logger.LogError("An unhandled error occurred while communicating with Stripe: {Error}", error);
errorMessage = "Something went wrong with your request. Please contact support.";
context.HttpContext.Response.StatusCode = StatusCodes.Status500InternalServerError;
}
else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message))
{ {
errorMessage = exception.Message; errorMessage = exception.Message;

View File

@ -7,7 +7,7 @@ namespace Bit.Core.AdminConsole.Services;
public interface IProviderService public interface IProviderService
{ {
Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key); Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null);
Task UpdateAsync(Provider provider, bool updateBilling = false); Task UpdateAsync(Provider provider, bool updateBilling = false);
Task<List<ProviderUser>> InviteUserAsync(ProviderUserInvite<string> invite); Task<List<ProviderUser>> InviteUserAsync(ProviderUserInvite<string> invite);

View File

@ -7,7 +7,7 @@ namespace Bit.Core.AdminConsole.Services.NoopImplementations;
public class NoopProviderService : IProviderService public class NoopProviderService : IProviderService
{ {
public Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) => throw new NotImplementedException(); public Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null) => throw new NotImplementedException();
public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException();

View File

@ -1,9 +1,9 @@
namespace Bit.Core.Billing; namespace Bit.Core.Billing;
public class BillingException( public class BillingException(
string clientFriendlyMessage, string response = null,
string internalMessage = null, string message = null,
Exception innerException = null) : Exception(internalMessage, innerException) Exception innerException = null) : Exception(message, innerException)
{ {
public string ClientFriendlyMessage { get; set; } = clientFriendlyMessage; public string Response { get; } = response ?? "Something went wrong with your request. Please contact support.";
} }

View File

@ -21,6 +21,11 @@ public static class StripeConstants
public const string SecretsManagerStandalone = "sm-standalone"; public const string SecretsManagerStandalone = "sm-standalone";
} }
public static class ErrorCodes
{
public const string CustomerTaxLocationInvalid = "customer_tax_location_invalid";
}
public static class PaymentMethodTypes public static class PaymentMethodTypes
{ {
public const string Card = "card"; public const string Card = "card";

View File

@ -2,6 +2,7 @@
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Stripe; using Stripe;
@ -24,9 +25,9 @@ public static class BillingExtensions
PlanType: PlanType.TeamsMonthly or PlanType.EnterpriseMonthly PlanType: PlanType.TeamsMonthly or PlanType.EnterpriseMonthly
}; };
public static bool IsStripeEnabled(this Organization organization) public static bool IsStripeEnabled(this ISubscriber subscriber)
=> !string.IsNullOrEmpty(organization.GatewayCustomerId) && => !string.IsNullOrEmpty(subscriber.GatewayCustomerId) &&
!string.IsNullOrEmpty(organization.GatewaySubscriptionId); !string.IsNullOrEmpty(subscriber.GatewaySubscriptionId);
public static bool IsUnverifiedBankAccount(this SetupIntent setupIntent) => public static bool IsUnverifiedBankAccount(this SetupIntent setupIntent) =>
setupIntent is setupIntent is

View File

@ -3,7 +3,7 @@ using Bit.Core.Billing.Enums;
namespace Bit.Core.Billing.Models; namespace Bit.Core.Billing.Models;
public record ConfiguredProviderPlanDTO( public record ConfiguredProviderPlan(
Guid Id, Guid Id,
Guid ProviderId, Guid ProviderId,
PlanType PlanType, PlanType PlanType,
@ -11,9 +11,9 @@ public record ConfiguredProviderPlanDTO(
int PurchasedSeats, int PurchasedSeats,
int AssignedSeats) int AssignedSeats)
{ {
public static ConfiguredProviderPlanDTO From(ProviderPlan providerPlan) => public static ConfiguredProviderPlan From(ProviderPlan providerPlan) =>
providerPlan.IsConfigured() providerPlan.IsConfigured()
? new ConfiguredProviderPlanDTO( ? new ConfiguredProviderPlan(
providerPlan.Id, providerPlan.Id,
providerPlan.ProviderId, providerPlan.ProviderId,
providerPlan.PlanType, providerPlan.PlanType,

View File

@ -1,9 +0,0 @@
using Stripe;
namespace Bit.Core.Billing.Models;
public record ConsolidatedBillingSubscriptionDTO(
List<ConfiguredProviderPlanDTO> ProviderPlans,
Subscription Subscription,
TaxInformationDTO TaxInformation,
SubscriptionSuspensionDTO Suspension);

View File

@ -3,4 +3,4 @@
public record PaymentInformationDTO( public record PaymentInformationDTO(
long AccountCredit, long AccountCredit,
MaskedPaymentMethodDTO PaymentMethod, MaskedPaymentMethodDTO PaymentMethod,
TaxInformationDTO TaxInformation); TaxInformation TaxInformation);

View File

@ -1,6 +1,6 @@
namespace Bit.Core.Billing.Models; namespace Bit.Core.Billing.Models;
public record SubscriptionSuspensionDTO( public record SubscriptionSuspension(
DateTime SuspensionDate, DateTime SuspensionDate,
DateTime UnpaidPeriodEndDate, DateTime UnpaidPeriodEndDate,
int GracePeriod); int GracePeriod);

View File

@ -1,6 +1,6 @@
namespace Bit.Core.Billing.Models; namespace Bit.Core.Billing.Models;
public record TaxInformationDTO( public record TaxInformation(
string Country, string Country,
string PostalCode, string PostalCode,
string TaxId, string TaxId,

View File

@ -3,8 +3,8 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing.Entities; using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
using Stripe;
namespace Bit.Core.Billing.Services; namespace Bit.Core.Billing.Services;
@ -24,16 +24,6 @@ public interface IProviderBillingService
Organization organization, Organization organization,
int seats); int seats);
/// <summary>
/// Create a Stripe <see cref="Stripe.Customer"/> for the specified <paramref name="provider"/> utilizing the provided <paramref name="taxInfo"/>.
/// </summary>
/// <param name="provider">The <see cref="Provider"/> to create a Stripe customer for.</param>
/// <param name="taxInfo">The <see cref="TaxInfo"/> to use for calculating the customer's automatic tax.</param>
/// <returns></returns>
Task CreateCustomer(
Provider provider,
TaxInfo taxInfo);
/// <summary> /// <summary>
/// Create a Stripe <see cref="Stripe.Customer"/> for the provided client <paramref name="organization"/> utilizing /// Create a Stripe <see cref="Stripe.Customer"/> for the provided client <paramref name="organization"/> utilizing
/// the address and tax information of its <paramref name="provider"/>. /// the address and tax information of its <paramref name="provider"/>.
@ -65,15 +55,6 @@ public interface IProviderBillingService
Guid providerId, Guid providerId,
PlanType planType); PlanType planType);
/// <summary>
/// Retrieves the <paramref name="provider"/>'s consolidated billing subscription, which includes their Stripe subscription and configured provider plans.
/// </summary>
/// <param name="provider">The provider to retrieve the consolidated billing subscription for.</param>
/// <returns>A <see cref="ConsolidatedBillingSubscriptionDTO"/> containing the provider's Stripe <see cref="Stripe.Subscription"/> and a list of <see cref="ConfiguredProviderPlanDTO"/>s representing their configured plans.</returns>
/// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<ConsolidatedBillingSubscriptionDTO> GetConsolidatedBillingSubscription(
Provider provider);
/// <summary> /// <summary>
/// Scales the <paramref name="provider"/>'s seats for the specified <paramref name="planType"/> using the provided <paramref name="seatAdjustment"/>. /// Scales the <paramref name="provider"/>'s seats for the specified <paramref name="planType"/> using the provided <paramref name="seatAdjustment"/>.
/// This operation may autoscale the provider's Stripe <see cref="Stripe.Subscription"/> depending on the <paramref name="provider"/>'s seat minimum for the /// This operation may autoscale the provider's Stripe <see cref="Stripe.Subscription"/> depending on the <paramref name="provider"/>'s seat minimum for the
@ -88,11 +69,23 @@ public interface IProviderBillingService
int seatAdjustment); int seatAdjustment);
/// <summary> /// <summary>
/// Starts a Stripe <see cref="Stripe.Subscription"/> for the given <paramref name="provider"/> given it has an existing Stripe <see cref="Stripe.Customer"/>. /// For use during the provider setup process, this method creates a Stripe <see cref="Stripe.Customer"/> for the specified <paramref name="provider"/> utilizing the provided <paramref name="taxInfo"/>.
/// </summary>
/// <param name="provider">The <see cref="Provider"/> to create a Stripe customer for.</param>
/// <param name="taxInfo">The <see cref="TaxInfo"/> to use for calculating the customer's automatic tax.</param>
/// <returns>The newly created <see cref="Stripe.Customer"/> for the <paramref name="provider"/>.</returns>
Task<Customer> SetupCustomer(
Provider provider,
TaxInfo taxInfo);
/// <summary>
/// For use during the provider setup process, this method starts a Stripe <see cref="Stripe.Subscription"/> for the given <paramref name="provider"/>.
/// <see cref="Provider"/> subscriptions will always be started with a <see cref="Stripe.SubscriptionItem"/> for both the <see cref="PlanType.TeamsMonthly"/> /// <see cref="Provider"/> subscriptions will always be started with a <see cref="Stripe.SubscriptionItem"/> for both the <see cref="PlanType.TeamsMonthly"/>
/// and <see cref="PlanType.EnterpriseMonthly"/> plan, and the quantity for each item will be equal the provider's seat minimum for each respective plan. /// and <see cref="PlanType.EnterpriseMonthly"/> plan, and the quantity for each item will be equal the provider's seat minimum for each respective plan.
/// </summary> /// </summary>
/// <param name="provider">The provider to create the <see cref="Stripe.Subscription"/> for.</param> /// <param name="provider">The provider to create the <see cref="Stripe.Subscription"/> for.</param>
Task StartSubscription( /// <returns>The newly created <see cref="Stripe.Subscription"/> for the <paramref name="provider"/>.</returns>
/// <remarks>This method requires the <paramref name="provider"/> to already have a linked Stripe <see cref="Stripe.Customer"/> via its <see cref="Provider.GatewayCustomerId"/> field.</remarks>
Task<Subscription> SetupSubscription(
Provider provider); Provider provider);
} }

View File

@ -1,7 +1,6 @@
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models.BitStripe;
using Stripe; using Stripe;
namespace Bit.Core.Billing.Services; namespace Bit.Core.Billing.Services;
@ -47,18 +46,6 @@ public interface ISubscriberService
ISubscriber subscriber, ISubscriber subscriber,
CustomerGetOptions customerGetOptions = null); CustomerGetOptions customerGetOptions = null);
/// <summary>
/// Retrieves a list of Stripe <see cref="Invoice"/> objects using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewayCustomerId"/> property.
/// </summary>
/// <param name="subscriber">The subscriber to retrieve the Stripe invoices for.</param>
/// <param name="invoiceListOptions">Optional parameters that can be passed to Stripe to expand, modify or filter the invoices. The <see cref="subscriber"/>'s
/// <see cref="ISubscriber.GatewayCustomerId"/> will be automatically attached to the provided options as the <see cref="InvoiceListOptions.Customer"/> parameter.</param>
/// <returns>A list of Stripe <see cref="Invoice"/> objects.</returns>
/// <remarks>This method opts for returning an empty list rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<List<Invoice>> GetInvoices(
ISubscriber subscriber,
StripeInvoiceListOptions invoiceListOptions = null);
/// <summary> /// <summary>
/// Retrieves the account credit, a masked representation of the default payment method and the tax information for the /// Retrieves the account credit, a masked representation of the default payment method and the tax information for the
/// provided <paramref name="subscriber"/>. This is essentially a consolidated invocation of the <see cref="GetPaymentMethod"/> /// provided <paramref name="subscriber"/>. This is essentially a consolidated invocation of the <see cref="GetPaymentMethod"/>
@ -106,10 +93,10 @@ public interface ISubscriberService
/// Retrieves the <see cref="subscriber"/>'s tax information using their Stripe <see cref="Stripe.Customer"/>'s <see cref="Stripe.Customer.Address"/>. /// Retrieves the <see cref="subscriber"/>'s tax information using their Stripe <see cref="Stripe.Customer"/>'s <see cref="Stripe.Customer.Address"/>.
/// </summary> /// </summary>
/// <param name="subscriber">The subscriber to retrieve the tax information for.</param> /// <param name="subscriber">The subscriber to retrieve the tax information for.</param>
/// <returns>A <see cref="TaxInformationDTO"/> representing the <paramref name="subscriber"/>'s tax information.</returns> /// <returns>A <see cref="TaxInformation"/> representing the <paramref name="subscriber"/>'s tax information.</returns>
/// <exception cref="ArgumentNullException">Thrown when the <paramref name="subscriber"/> is <see langword="null"/>.</exception> /// <exception cref="ArgumentNullException">Thrown when the <paramref name="subscriber"/> is <see langword="null"/>.</exception>
/// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks> /// <remarks>This method opts for returning <see langword="null"/> rather than throwing exceptions, making it ideal for surfacing data from API endpoints.</remarks>
Task<TaxInformationDTO> GetTaxInformation( Task<TaxInformation> GetTaxInformation(
ISubscriber subscriber); ISubscriber subscriber);
/// <summary> /// <summary>
@ -137,10 +124,10 @@ public interface ISubscriberService
/// Updates the tax information for the provided <paramref name="subscriber"/>. /// Updates the tax information for the provided <paramref name="subscriber"/>.
/// </summary> /// </summary>
/// <param name="subscriber">The <paramref name="subscriber"/> to update the tax information for.</param> /// <param name="subscriber">The <paramref name="subscriber"/> to update the tax information for.</param>
/// <param name="taxInformation">A <see cref="TaxInformationDTO"/> representing the <paramref name="subscriber"/>'s updated tax information.</param> /// <param name="taxInformation">A <see cref="TaxInformation"/> representing the <paramref name="subscriber"/>'s updated tax information.</param>
Task UpdateTaxInformation( Task UpdateTaxInformation(
ISubscriber subscriber, ISubscriber subscriber,
TaxInformationDTO taxInformation); TaxInformation taxInformation);
/// <summary> /// <summary>
/// Verifies the subscriber's pending bank account using the provided <paramref name="microdeposits"/>. /// Verifies the subscriber's pending bank account using the provided <paramref name="microdeposits"/>.

View File

@ -2,7 +2,6 @@
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models.BitStripe;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
@ -37,7 +36,7 @@ public class SubscriberService(
{ {
logger.LogWarning("Cannot cancel subscription ({ID}) that's already inactive", subscription.Id); logger.LogWarning("Cannot cancel subscription ({ID}) that's already inactive", subscription.Id);
throw ContactSupport(); throw new BillingException();
} }
var metadata = new Dictionary<string, string> var metadata = new Dictionary<string, string>
@ -148,7 +147,7 @@ public class SubscriberService(
{ {
logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId)); logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId));
throw ContactSupport(); throw new BillingException();
} }
try try
@ -163,48 +162,16 @@ public class SubscriberService(
logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})",
subscriber.GatewayCustomerId, subscriber.Id); subscriber.GatewayCustomerId, subscriber.Id);
throw ContactSupport(); throw new BillingException();
} }
catch (StripeException exception) catch (StripeException stripeException)
{ {
logger.LogError("An error occurred while trying to retrieve Stripe customer ({CustomerID}) for subscriber ({SubscriberID}): {Error}", logger.LogError("An error occurred while trying to retrieve Stripe customer ({CustomerID}) for subscriber ({SubscriberID}): {Error}",
subscriber.GatewayCustomerId, subscriber.Id, exception.Message); subscriber.GatewayCustomerId, subscriber.Id, stripeException.Message);
throw ContactSupport("An error occurred while trying to retrieve a Stripe Customer", exception); throw new BillingException(
} message: "An error occurred while trying to retrieve a Stripe customer",
} innerException: stripeException);
public async Task<List<Invoice>> GetInvoices(
ISubscriber subscriber,
StripeInvoiceListOptions invoiceListOptions = null)
{
ArgumentNullException.ThrowIfNull(subscriber);
if (string.IsNullOrEmpty(subscriber.GatewayCustomerId))
{
logger.LogError("Cannot retrieve invoices for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId));
return [];
}
try
{
if (invoiceListOptions == null)
{
invoiceListOptions = new StripeInvoiceListOptions { Customer = subscriber.GatewayCustomerId };
}
else
{
invoiceListOptions.Customer = subscriber.GatewayCustomerId;
}
return await stripeAdapter.InvoiceListAsync(invoiceListOptions);
}
catch (StripeException exception)
{
logger.LogError("An error occurred while trying to retrieve Stripe invoices for subscriber ({SubscriberID}): {Error}", subscriber.Id, exception.Message);
return [];
} }
} }
@ -294,7 +261,7 @@ public class SubscriberService(
{ {
logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId)); logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId));
throw ContactSupport(); throw new BillingException();
} }
try try
@ -309,18 +276,20 @@ public class SubscriberService(
logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})",
subscriber.GatewaySubscriptionId, subscriber.Id); subscriber.GatewaySubscriptionId, subscriber.Id);
throw ContactSupport(); throw new BillingException();
} }
catch (StripeException exception) catch (StripeException stripeException)
{ {
logger.LogError("An error occurred while trying to retrieve Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID}): {Error}", logger.LogError("An error occurred while trying to retrieve Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID}): {Error}",
subscriber.GatewaySubscriptionId, subscriber.Id, exception.Message); subscriber.GatewaySubscriptionId, subscriber.Id, stripeException.Message);
throw ContactSupport("An error occurred while trying to retrieve a Stripe Subscription", exception); throw new BillingException(
message: "An error occurred while trying to retrieve a Stripe subscription",
innerException: stripeException);
} }
} }
public async Task<TaxInformationDTO> GetTaxInformation( public async Task<TaxInformation> GetTaxInformation(
ISubscriber subscriber) ISubscriber subscriber)
{ {
ArgumentNullException.ThrowIfNull(subscriber); ArgumentNullException.ThrowIfNull(subscriber);
@ -337,7 +306,7 @@ public class SubscriberService(
if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) if (string.IsNullOrEmpty(subscriber.GatewayCustomerId))
{ {
throw ContactSupport(); throw new BillingException();
} }
var stripeCustomer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions var stripeCustomer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions
@ -353,7 +322,7 @@ public class SubscriberService(
{ {
logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId);
throw ContactSupport(); throw new BillingException();
} }
if (braintreeCustomer.DefaultPaymentMethod != null) if (braintreeCustomer.DefaultPaymentMethod != null)
@ -369,7 +338,7 @@ public class SubscriberService(
logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}",
braintreeCustomerId, updateCustomerResult.Message); braintreeCustomerId, updateCustomerResult.Message);
throw ContactSupport(); throw new BillingException();
} }
var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token);
@ -384,7 +353,7 @@ public class SubscriberService(
"Failed to delete Braintree payment method for Customer ({ID}), re-linked payment method. Message: {Message}", "Failed to delete Braintree payment method for Customer ({ID}), re-linked payment method. Message: {Message}",
braintreeCustomerId, deletePaymentMethodResult.Message); braintreeCustomerId, deletePaymentMethodResult.Message);
throw ContactSupport(); throw new BillingException();
} }
} }
else else
@ -437,7 +406,7 @@ public class SubscriberService(
{ {
logger.LogError("Updated payment method for ({SubscriberID}) must contain a token", subscriber.Id); logger.LogError("Updated payment method for ({SubscriberID}) must contain a token", subscriber.Id);
throw ContactSupport(); throw new BillingException();
} }
// ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault // ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault
@ -462,7 +431,7 @@ public class SubscriberService(
{ {
logger.LogError("There were more than 1 setup intents for subscriber's ({SubscriberID}) updated payment method", subscriber.Id); logger.LogError("There were more than 1 setup intents for subscriber's ({SubscriberID}) updated payment method", subscriber.Id);
throw ContactSupport(); throw new BillingException();
} }
var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First(); var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First();
@ -551,7 +520,7 @@ public class SubscriberService(
{ {
logger.LogError("Failed to retrieve Braintree customer ({BraintreeCustomerId}) when updating payment method for subscriber ({SubscriberID})", braintreeCustomerId, subscriber.Id); logger.LogError("Failed to retrieve Braintree customer ({BraintreeCustomerId}) when updating payment method for subscriber ({SubscriberID})", braintreeCustomerId, subscriber.Id);
throw ContactSupport(); throw new BillingException();
} }
await ReplaceBraintreePaymentMethodAsync(braintreeCustomer, token); await ReplaceBraintreePaymentMethodAsync(braintreeCustomer, token);
@ -570,14 +539,14 @@ public class SubscriberService(
{ {
logger.LogError("Cannot update subscriber's ({SubscriberID}) payment method to type ({PaymentMethodType}) as it is not supported", subscriber.Id, type.ToString()); logger.LogError("Cannot update subscriber's ({SubscriberID}) payment method to type ({PaymentMethodType}) as it is not supported", subscriber.Id, type.ToString());
throw ContactSupport(); throw new BillingException();
} }
} }
} }
public async Task UpdateTaxInformation( public async Task UpdateTaxInformation(
ISubscriber subscriber, ISubscriber subscriber,
TaxInformationDTO taxInformation) TaxInformation taxInformation)
{ {
ArgumentNullException.ThrowIfNull(subscriber); ArgumentNullException.ThrowIfNull(subscriber);
ArgumentNullException.ThrowIfNull(taxInformation); ArgumentNullException.ThrowIfNull(taxInformation);
@ -635,7 +604,7 @@ public class SubscriberService(
{ {
logger.LogError("No setup intent ID exists to verify for subscriber with ID ({SubscriberID})", subscriber.Id); logger.LogError("No setup intent ID exists to verify for subscriber with ID ({SubscriberID})", subscriber.Id);
throw ContactSupport(); throw new BillingException();
} }
var (amount1, amount2) = microdeposits; var (amount1, amount2) = microdeposits;
@ -706,7 +675,7 @@ public class SubscriberService(
logger.LogError("Failed to create Braintree customer for subscriber ({ID})", subscriber.Id); logger.LogError("Failed to create Braintree customer for subscriber ({ID})", subscriber.Id);
throw ContactSupport(); throw new BillingException();
} }
private async Task<MaskedPaymentMethodDTO> GetMaskedPaymentMethodDTOAsync( private async Task<MaskedPaymentMethodDTO> GetMaskedPaymentMethodDTOAsync(
@ -751,7 +720,7 @@ public class SubscriberService(
return MaskedPaymentMethodDTO.From(setupIntent); return MaskedPaymentMethodDTO.From(setupIntent);
} }
private static TaxInformationDTO GetTaxInformationDTOFrom( private static TaxInformation GetTaxInformationDTOFrom(
Customer customer) Customer customer)
{ {
if (customer.Address == null) if (customer.Address == null)
@ -759,7 +728,7 @@ public class SubscriberService(
return null; return null;
} }
return new TaxInformationDTO( return new TaxInformation(
customer.Address.Country, customer.Address.Country,
customer.Address.PostalCode, customer.Address.PostalCode,
customer.TaxIds?.FirstOrDefault()?.Value, customer.TaxIds?.FirstOrDefault()?.Value,
@ -825,7 +794,7 @@ public class SubscriberService(
{ {
logger.LogError("Failed to replace payment method for Braintree customer ({ID}) - Creation of new payment method failed | Error: {Error}", customer.Id, createPaymentMethodResult.Message); logger.LogError("Failed to replace payment method for Braintree customer ({ID}) - Creation of new payment method failed | Error: {Error}", customer.Id, createPaymentMethodResult.Message);
throw ContactSupport(); throw new BillingException();
} }
var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync( var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync(
@ -839,7 +808,7 @@ public class SubscriberService(
await braintreeGateway.PaymentMethod.DeleteAsync(createPaymentMethodResult.Target.Token); await braintreeGateway.PaymentMethod.DeleteAsync(createPaymentMethodResult.Target.Token);
throw ContactSupport(); throw new BillingException();
} }
if (existingDefaultPaymentMethod != null) if (existingDefaultPaymentMethod != null)

View File

@ -8,12 +8,7 @@ public static class Utilities
{ {
public const string BraintreeCustomerIdKey = "btCustomerId"; public const string BraintreeCustomerIdKey = "btCustomerId";
public static BillingException ContactSupport( public static async Task<SubscriptionSuspension> GetSubscriptionSuspensionAsync(
string internalMessage = null,
Exception innerException = null) => new("Something went wrong with your request. Please contact support.",
internalMessage, innerException);
public static async Task<SubscriptionSuspensionDTO> GetSuspensionAsync(
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
Subscription subscription) Subscription subscription)
{ {
@ -49,7 +44,7 @@ public static class Utilities
const int gracePeriod = 14; const int gracePeriod = 14;
return new SubscriptionSuspensionDTO( return new SubscriptionSuspension(
firstOverdueInvoice.Created.AddDays(gracePeriod), firstOverdueInvoice.Created.AddDays(gracePeriod),
firstOverdueInvoice.PeriodEnd, firstOverdueInvoice.PeriodEnd,
gracePeriod); gracePeriod);
@ -67,7 +62,7 @@ public static class Utilities
const int gracePeriod = 30; const int gracePeriod = 30;
return new SubscriptionSuspensionDTO( return new SubscriptionSuspension(
firstOverdueInvoice.DueDate.Value.AddDays(gracePeriod), firstOverdueInvoice.DueDate.Value.AddDays(gracePeriod),
firstOverdueInvoice.PeriodEnd, firstOverdueInvoice.PeriodEnd,
gracePeriod); gracePeriod);
@ -75,4 +70,21 @@ public static class Utilities
default: return null; default: return null;
} }
} }
public static TaxInformation GetTaxInformation(Customer customer)
{
if (customer.Address == null)
{
return null;
}
return new TaxInformation(
customer.Address.Country,
customer.Address.PostalCode,
customer.TaxIds?.FirstOrDefault()?.Value,
customer.Address.Line1,
customer.Address.Line2,
customer.Address.City,
customer.Address.State);
}
} }

View File

@ -1,9 +1,8 @@
using Bit.Core.Billing.Enums; using Bit.Core.Billing;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Stripe; using Stripe;
using static Bit.Core.Billing.Utilities;
namespace Bit.Core.Models.Business; namespace Bit.Core.Models.Business;
public class ProviderSubscriptionUpdate : SubscriptionUpdate public class ProviderSubscriptionUpdate : SubscriptionUpdate
@ -21,7 +20,8 @@ public class ProviderSubscriptionUpdate : SubscriptionUpdate
{ {
if (!planType.SupportsConsolidatedBilling()) if (!planType.SupportsConsolidatedBilling())
{ {
throw ContactSupport($"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing"); throw new BillingException(
message: $"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing");
} }
var plan = Utilities.StaticStore.GetPlan(planType); var plan = Utilities.StaticStore.GetPlan(planType);

View File

@ -6,15 +6,19 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Enums; using Bit.Core.Models.Api;
using Bit.Core.Models.BitStripe;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Http.HttpResults;
using NSubstitute; using NSubstitute;
using NSubstitute.ReturnsExtensions; using NSubstitute.ReturnsExtensions;
@ -29,7 +33,74 @@ namespace Bit.Api.Test.Billing.Controllers;
[SutProviderCustomize] [SutProviderCustomize]
public class ProviderBillingControllerTests public class ProviderBillingControllerTests
{ {
#region GetInvoicesAsync #region GetInvoicesAsync & TryGetBillableProviderForAdminOperations
[Theory, BitAutoData]
public async Task GetInvoicesAsync_FFDisabled_NotFound(
Guid providerId,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(false);
var result = await sutProvider.Sut.GetInvoicesAsync(providerId);
AssertNotFound(result);
}
[Theory, BitAutoData]
public async Task GetInvoicesAsync_NullProvider_NotFound(
Guid providerId,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId).ReturnsNull();
var result = await sutProvider.Sut.GetInvoicesAsync(providerId);
AssertNotFound(result);
}
[Theory, BitAutoData]
public async Task GetInvoicesAsync_NotProviderUser_Unauthorized(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(provider.Id)
.Returns(false);
var result = await sutProvider.Sut.GetInvoicesAsync(provider.Id);
AssertUnauthorized(result);
}
[Theory, BitAutoData]
public async Task GetInvoicesAsync_ProviderNotBillable_Unauthorized(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
provider.Type = ProviderType.Reseller;
provider.Status = ProviderStatusType.Created;
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(provider.Id)
.Returns(true);
var result = await sutProvider.Sut.GetInvoicesAsync(provider.Id);
AssertUnauthorized(result);
}
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetInvoices_Ok( public async Task GetInvoices_Ok(
@ -73,7 +144,9 @@ public class ProviderBillingControllerTests
} }
}; };
sutProvider.GetDependency<ISubscriberService>().GetInvoices(provider).Returns(invoices); sutProvider.GetDependency<IStripeAdapter>().InvoiceListAsync(Arg.Is<StripeInvoiceListOptions>(
options =>
options.Customer == provider.GatewayCustomerId)).Returns(invoices);
var result = await sutProvider.Sut.GetInvoicesAsync(provider.Id); var result = await sutProvider.Sut.GetInvoicesAsync(provider.Id);
@ -108,6 +181,27 @@ public class ProviderBillingControllerTests
#region GenerateClientInvoiceReportAsync #region GenerateClientInvoiceReportAsync
[Theory, BitAutoData]
public async Task GenerateClientInvoiceReportAsync_NullReportContent_ServerError(
Provider provider,
string invoiceId,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
sutProvider.GetDependency<IProviderBillingService>().GenerateClientInvoiceReport(invoiceId)
.ReturnsNull();
var result = await sutProvider.Sut.GenerateClientInvoiceReportAsync(provider.Id, invoiceId);
Assert.IsType<JsonHttpResult<ErrorResponseModel>>(result);
var response = (JsonHttpResult<ErrorResponseModel>)result;
Assert.Equal(StatusCodes.Status500InternalServerError, response.StatusCode);
Assert.Equal("We had a problem generating your invoice CSV. Please contact support.", response.Value.Message);
}
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GenerateClientInvoiceReportAsync_Ok( public async Task GenerateClientInvoiceReportAsync_Ok(
Provider provider, Provider provider,
@ -133,158 +227,6 @@ public class ProviderBillingControllerTests
#endregion #endregion
#region GetPaymentInformationAsync & TryGetBillableProviderForAdminOperation
[Theory, BitAutoData]
public async Task GetPaymentInformationAsync_FFDisabled_NotFound(
Guid providerId,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(false);
var result = await sutProvider.Sut.GetPaymentInformationAsync(providerId);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetPaymentInformationAsync_NullProvider_NotFound(
Guid providerId,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(providerId).ReturnsNull();
var result = await sutProvider.Sut.GetPaymentInformationAsync(providerId);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetPaymentInformationAsync_NotProviderUser_Unauthorized(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(provider.Id)
.Returns(false);
var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id);
Assert.IsType<UnauthorizedHttpResult>(result);
}
[Theory, BitAutoData]
public async Task GetPaymentInformationAsync_ProviderNotBillable_Unauthorized(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
.Returns(true);
provider.Type = ProviderType.Reseller;
provider.Status = ProviderStatusType.Created;
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
sutProvider.GetDependency<ICurrentContext>().ProviderProviderAdmin(provider.Id)
.Returns(true);
var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id);
Assert.IsType<UnauthorizedHttpResult>(result);
}
[Theory, BitAutoData]
public async Task GetPaymentInformation_PaymentInformationNull_NotFound(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetPaymentInformation(provider).ReturnsNull();
var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetPaymentInformation_Ok(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
var maskedPaymentMethod = new MaskedPaymentMethodDTO(PaymentMethodType.Card, "VISA *1234", false);
var taxInformation =
new TaxInformationDTO("US", "12345", "123456789", "123 Example St.", null, "Example Town", "NY");
sutProvider.GetDependency<ISubscriberService>().GetPaymentInformation(provider).Returns(new PaymentInformationDTO(
100,
maskedPaymentMethod,
taxInformation));
var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id);
Assert.IsType<Ok<PaymentInformationResponse>>(result);
var response = ((Ok<PaymentInformationResponse>)result).Value;
Assert.Equal(100, response.AccountCredit);
Assert.Equal(maskedPaymentMethod.Description, response.PaymentMethod.Description);
Assert.Equal(taxInformation.TaxId, response.TaxInformation.TaxId);
}
#endregion
#region GetPaymentMethodAsync
[Theory, BitAutoData]
public async Task GetPaymentMethod_PaymentMethodNull_NotFound(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetPaymentMethod(provider).ReturnsNull();
var result = await sutProvider.Sut.GetPaymentMethodAsync(provider.Id);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetPaymentMethod_Ok(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetPaymentMethod(provider).Returns(new MaskedPaymentMethodDTO(
PaymentMethodType.Card, "Description", false));
var result = await sutProvider.Sut.GetPaymentMethodAsync(provider.Id);
Assert.IsType<Ok<MaskedPaymentMethodResponse>>(result);
var response = ((Ok<MaskedPaymentMethodResponse>)result).Value;
Assert.Equal(PaymentMethodType.Card, response.Type);
Assert.Equal("Description", response.Description);
Assert.False(response.NeedsVerification);
}
#endregion
#region GetSubscriptionAsync & TryGetBillableProviderForServiceUserOperation #region GetSubscriptionAsync & TryGetBillableProviderForServiceUserOperation
[Theory, BitAutoData] [Theory, BitAutoData]
@ -297,7 +239,7 @@ public class ProviderBillingControllerTests
var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
Assert.IsType<NotFound>(result); AssertNotFound(result);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
@ -312,7 +254,7 @@ public class ProviderBillingControllerTests
var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
Assert.IsType<NotFound>(result); AssertNotFound(result);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
@ -330,7 +272,7 @@ public class ProviderBillingControllerTests
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<UnauthorizedHttpResult>(result); AssertUnauthorized(result);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
@ -351,21 +293,7 @@ public class ProviderBillingControllerTests
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<UnauthorizedHttpResult>(result); AssertUnauthorized(result);
}
[Theory, BitAutoData]
public async Task GetSubscriptionAsync_NullConsolidatedBillingSubscription_NotFound(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableServiceUserInputs(provider, sutProvider);
sutProvider.GetDependency<IProviderBillingService>().GetConsolidatedBillingSubscription(provider).ReturnsNull();
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<NotFound>(result);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
@ -375,51 +303,83 @@ public class ProviderBillingControllerTests
{ {
ConfigureStableServiceUserInputs(provider, sutProvider); ConfigureStableServiceUserInputs(provider, sutProvider);
var configuredProviderPlans = new List<ConfiguredProviderPlanDTO> var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
{
new (Guid.NewGuid(), provider.Id, PlanType.TeamsMonthly, 50, 10, 30), var (thisYear, thisMonth, _) = DateTime.UtcNow;
new (Guid.NewGuid(), provider.Id , PlanType.EnterpriseMonthly, 100, 0, 90) var daysInThisMonth = DateTime.DaysInMonth(thisYear, thisMonth);
};
var subscription = new Subscription var subscription = new Subscription
{ {
Status = "unpaid", CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
CurrentPeriodEnd = new DateTime(2024, 6, 30), CurrentPeriodEnd = new DateTime(thisYear, thisMonth, daysInThisMonth),
Customer = new Customer Customer = new Customer
{ {
Balance = 100000, Address = new Address
Discount = new Discount
{ {
Coupon = new Coupon Country = "US",
{ PostalCode = "12345",
PercentOff = 10 Line1 = "123 Example St.",
} Line2 = "Unit 1",
} City = "Example Town",
State = "NY"
},
Balance = 100000,
Discount = new Discount { Coupon = new Coupon { PercentOff = 10 } },
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Value = "123456789" }] }
},
Status = "unpaid",
};
stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, Arg.Is<SubscriptionGetOptions>(
options =>
options.Expand.Contains("customer.tax_ids") &&
options.Expand.Contains("test_clock"))).Returns(subscription);
var lastMonth = thisMonth - 1;
var daysInLastMonth = DateTime.DaysInMonth(thisYear, lastMonth);
var overdueInvoice = new Invoice
{
Id = "invoice_id",
Status = "open",
Created = new DateTime(thisYear, lastMonth, 1),
PeriodEnd = new DateTime(thisYear, lastMonth, daysInLastMonth),
Attempted = true
};
stripeAdapter.InvoiceSearchAsync(Arg.Is<InvoiceSearchOptions>(
options => options.Query == $"subscription:'{subscription.Id}' status:'open'"))
.Returns([overdueInvoice]);
var providerPlans = new List<ProviderPlan>
{
new ()
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.TeamsMonthly,
SeatMinimum = 50,
PurchasedSeats = 10,
AllocatedSeats = 60
},
new ()
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.EnterpriseMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
AllocatedSeats = 90
} }
}; };
var taxInformation = sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
new TaxInformationDTO("US", "12345", "123456789", "123 Example St.", null, "Example Town", "NY");
var suspension = new SubscriptionSuspensionDTO(
new DateTime(2024, 7, 30),
new DateTime(2024, 5, 30),
30);
var consolidatedBillingSubscription = new ConsolidatedBillingSubscriptionDTO(
configuredProviderPlans,
subscription,
taxInformation,
suspension);
sutProvider.GetDependency<IProviderBillingService>().GetConsolidatedBillingSubscription(provider)
.Returns(consolidatedBillingSubscription);
var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id);
Assert.IsType<Ok<ConsolidatedBillingSubscriptionResponse>>(result); Assert.IsType<Ok<ProviderSubscriptionResponse>>(result);
var response = ((Ok<ConsolidatedBillingSubscriptionResponse>)result).Value; var response = ((Ok<ProviderSubscriptionResponse>)result).Value;
Assert.Equal(subscription.Status, response.Status); Assert.Equal(subscription.Status, response.Status);
Assert.Equal(subscription.CurrentPeriodEnd, response.CurrentPeriodEndDate); Assert.Equal(subscription.CurrentPeriodEnd, response.CurrentPeriodEndDate);
@ -431,7 +391,7 @@ public class ProviderBillingControllerTests
Assert.NotNull(providerTeamsPlan); Assert.NotNull(providerTeamsPlan);
Assert.Equal(50, providerTeamsPlan.SeatMinimum); Assert.Equal(50, providerTeamsPlan.SeatMinimum);
Assert.Equal(10, providerTeamsPlan.PurchasedSeats); Assert.Equal(10, providerTeamsPlan.PurchasedSeats);
Assert.Equal(30, providerTeamsPlan.AssignedSeats); Assert.Equal(60, providerTeamsPlan.AssignedSeats);
Assert.Equal(60 * teamsPlan.PasswordManager.ProviderPortalSeatPrice, providerTeamsPlan.Cost); Assert.Equal(60 * teamsPlan.PasswordManager.ProviderPortalSeatPrice, providerTeamsPlan.Cost);
Assert.Equal("Monthly", providerTeamsPlan.Cadence); Assert.Equal("Monthly", providerTeamsPlan.Cadence);
@ -445,87 +405,46 @@ public class ProviderBillingControllerTests
Assert.Equal("Monthly", providerEnterprisePlan.Cadence); Assert.Equal("Monthly", providerEnterprisePlan.Cadence);
Assert.Equal(100000, response.AccountCredit); Assert.Equal(100000, response.AccountCredit);
Assert.Equal(taxInformation, response.TaxInformation);
var customer = subscription.Customer;
Assert.Equal(customer.Address.Country, response.TaxInformation.Country);
Assert.Equal(customer.Address.PostalCode, response.TaxInformation.PostalCode);
Assert.Equal(customer.TaxIds.First().Value, response.TaxInformation.TaxId);
Assert.Equal(customer.Address.Line1, response.TaxInformation.Line1);
Assert.Equal(customer.Address.Line2, response.TaxInformation.Line2);
Assert.Equal(customer.Address.City, response.TaxInformation.City);
Assert.Equal(customer.Address.State, response.TaxInformation.State);
Assert.Null(response.CancelAt); Assert.Null(response.CancelAt);
Assert.Equal(suspension, response.Suspension);
}
#endregion Assert.Equal(overdueInvoice.Created.AddDays(14), response.Suspension.SuspensionDate);
Assert.Equal(overdueInvoice.PeriodEnd, response.Suspension.UnpaidPeriodEndDate);
#region GetTaxInformationAsync Assert.Equal(14, response.Suspension.GracePeriod);
[Theory, BitAutoData]
public async Task GetTaxInformation_TaxInformationNull_NotFound(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetTaxInformation(provider).ReturnsNull();
var result = await sutProvider.Sut.GetTaxInformationAsync(provider.Id);
Assert.IsType<NotFound>(result);
}
[Theory, BitAutoData]
public async Task GetTaxInformation_Ok(
Provider provider,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
sutProvider.GetDependency<ISubscriberService>().GetTaxInformation(provider).Returns(new TaxInformationDTO(
"US",
"12345",
"123456789",
"123 Example St.",
null,
"Example Town",
"NY"));
var result = await sutProvider.Sut.GetTaxInformationAsync(provider.Id);
Assert.IsType<Ok<TaxInformationResponse>>(result);
var response = ((Ok<TaxInformationResponse>)result).Value;
Assert.Equal("US", response.Country);
Assert.Equal("12345", response.PostalCode);
Assert.Equal("123456789", response.TaxId);
Assert.Equal("123 Example St.", response.Line1);
Assert.Null(response.Line2);
Assert.Equal("Example Town", response.City);
Assert.Equal("NY", response.State);
}
#endregion
#region UpdatePaymentMethodAsync
[Theory, BitAutoData]
public async Task UpdatePaymentMethod_Ok(
Provider provider,
TokenizedPaymentMethodRequestBody requestBody,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
await sutProvider.Sut.UpdatePaymentMethodAsync(provider.Id, requestBody);
await sutProvider.GetDependency<ISubscriberService>().Received(1).UpdatePaymentMethod(
provider, Arg.Is<TokenizedPaymentMethodDTO>(
options => options.Type == requestBody.Type && options.Token == requestBody.Token));
await sutProvider.GetDependency<IStripeAdapter>().Received(1).SubscriptionUpdateAsync(
provider.GatewaySubscriptionId, Arg.Is<SubscriptionUpdateOptions>(
options => options.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically));
} }
#endregion #endregion
#region UpdateTaxInformationAsync #region UpdateTaxInformationAsync
[Theory, BitAutoData]
public async Task UpdateTaxInformation_NoCountry_BadRequest(
Provider provider,
TaxInformationRequestBody requestBody,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
requestBody.Country = null;
var result = await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody);
Assert.IsType<BadRequest<ErrorResponseModel>>(result);
var response = (BadRequest<ErrorResponseModel>)result;
Assert.Equal("Country and postal code are required to update your tax information.", response.Value.Message);
}
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdateTaxInformation_Ok( public async Task UpdateTaxInformation_Ok(
Provider provider, Provider provider,
@ -537,7 +456,7 @@ public class ProviderBillingControllerTests
await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody); await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody);
await sutProvider.GetDependency<ISubscriberService>().Received(1).UpdateTaxInformation( await sutProvider.GetDependency<ISubscriberService>().Received(1).UpdateTaxInformation(
provider, Arg.Is<TaxInformationDTO>( provider, Arg.Is<TaxInformation>(
options => options =>
options.Country == requestBody.Country && options.Country == requestBody.Country &&
options.PostalCode == requestBody.PostalCode && options.PostalCode == requestBody.PostalCode &&
@ -549,25 +468,4 @@ public class ProviderBillingControllerTests
} }
#endregion #endregion
#region VerifyBankAccount
[Theory, BitAutoData]
public async Task VerifyBankAccount_Ok(
Provider provider,
VerifyBankAccountRequestBody requestBody,
SutProvider<ProviderBillingController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
var result = await sutProvider.Sut.VerifyBankAccountAsync(provider.Id, requestBody);
Assert.IsType<Ok>(result);
await sutProvider.GetDependency<ISubscriberService>().Received(1).VerifyBankAccount(
provider,
(requestBody.Amount1, requestBody.Amount2));
}
#endregion
} }

View File

@ -39,38 +39,7 @@ public class ProviderClientsControllerTests
var result = await sutProvider.Sut.CreateAsync(provider.Id, requestBody); var result = await sutProvider.Sut.CreateAsync(provider.Id, requestBody);
Assert.IsType<UnauthorizedHttpResult>(result); AssertUnauthorized(result);
}
[Theory, BitAutoData]
public async Task CreateAsync_MissingClientOrganization_ServerError(
Provider provider,
CreateClientOrganizationRequestBody requestBody,
SutProvider<ProviderClientsController> sutProvider)
{
ConfigureStableAdminInputs(provider, sutProvider);
var user = new User();
sutProvider.GetDependency<IUserService>().GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
var clientOrganizationId = Guid.NewGuid();
sutProvider.GetDependency<IProviderService>().CreateOrganizationAsync(
provider.Id,
Arg.Any<OrganizationSignup>(),
requestBody.OwnerEmail,
user)
.Returns(new ProviderOrganization
{
OrganizationId = clientOrganizationId
});
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(clientOrganizationId).ReturnsNull();
var result = await sutProvider.Sut.CreateAsync(provider.Id, requestBody);
Assert.IsType<ProblemHttpResult>(result);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
@ -137,32 +106,11 @@ public class ProviderClientsControllerTests
var result = await sutProvider.Sut.UpdateAsync(provider.Id, providerOrganizationId, requestBody); var result = await sutProvider.Sut.UpdateAsync(provider.Id, providerOrganizationId, requestBody);
Assert.IsType<NotFound>(result); AssertNotFound(result);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdateAsync_NoOrganization_ServerError( public async Task UpdateAsync_AssignedSeats_Ok(
Provider provider,
Guid providerOrganizationId,
UpdateClientOrganizationRequestBody requestBody,
ProviderOrganization providerOrganization,
SutProvider<ProviderClientsController> sutProvider)
{
ConfigureStableServiceUserInputs(provider, sutProvider);
sutProvider.GetDependency<IProviderOrganizationRepository>().GetByIdAsync(providerOrganizationId)
.Returns(providerOrganization);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(providerOrganization.OrganizationId)
.ReturnsNull();
var result = await sutProvider.Sut.UpdateAsync(provider.Id, providerOrganizationId, requestBody);
Assert.IsType<ProblemHttpResult>(result);
}
[Theory, BitAutoData]
public async Task UpdateAsync_AssignedSeats_NoContent(
Provider provider, Provider provider,
Guid providerOrganizationId, Guid providerOrganizationId,
UpdateClientOrganizationRequestBody requestBody, UpdateClientOrganizationRequestBody requestBody,
@ -193,7 +141,7 @@ public class ProviderClientsControllerTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdateAsync_Name_NoContent( public async Task UpdateAsync_Name_Ok(
Provider provider, Provider provider,
Guid providerOrganizationId, Guid providerOrganizationId,
UpdateClientOrganizationRequestBody requestBody, UpdateClientOrganizationRequestBody requestBody,

View File

@ -4,14 +4,37 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Models.Api;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.HttpResults;
using NSubstitute; using NSubstitute;
using Xunit;
namespace Bit.Api.Test.Billing; namespace Bit.Api.Test.Billing;
public static class Utilities public static class Utilities
{ {
public static void AssertNotFound(IResult result)
{
Assert.IsType<NotFound<ErrorResponseModel>>(result);
var response = ((NotFound<ErrorResponseModel>)result).Value;
Assert.Equal("Resource not found.", response.Message);
}
public static void AssertUnauthorized(IResult result)
{
Assert.IsType<JsonHttpResult<ErrorResponseModel>>(result);
var response = (JsonHttpResult<ErrorResponseModel>)result;
Assert.Equal(StatusCodes.Status401Unauthorized, response.StatusCode);
Assert.Equal("Unauthorized.", response.Value.Message);
}
public static void ConfigureStableAdminInputs<T>( public static void ConfigureStableAdminInputs<T>(
Provider provider, Provider provider,
SutProvider<T> sutProvider) where T : BaseProviderController SutProvider<T> sutProvider) where T : BaseProviderController

View File

@ -5,7 +5,6 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models.BitStripe;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
@ -29,8 +28,9 @@ namespace Bit.Core.Test.Billing.Services;
public class SubscriberServiceTests public class SubscriberServiceTests
{ {
#region CancelSubscription #region CancelSubscription
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task CancelSubscription_SubscriptionInactive_ContactSupport( public async Task CancelSubscription_SubscriptionInactive_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -45,7 +45,7 @@ public class SubscriberServiceTests
.SubscriptionGetAsync(organization.GatewaySubscriptionId) .SubscriptionGetAsync(organization.GatewaySubscriptionId)
.Returns(subscription); .Returns(subscription);
await ThrowsContactSupportAsync(() => await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.CancelSubscription(organization, new OffboardingSurveyResponse(), false)); sutProvider.Sut.CancelSubscription(organization, new OffboardingSurveyResponse(), false));
await stripeAdapter await stripeAdapter
@ -192,9 +192,11 @@ public class SubscriberServiceTests
.DidNotReceiveWithAnyArgs() .DidNotReceiveWithAnyArgs()
.SubscriptionCancelAsync(Arg.Any<string>(), Arg.Any<SubscriptionCancelOptions>()); ; .SubscriptionCancelAsync(Arg.Any<string>(), Arg.Any<SubscriptionCancelOptions>()); ;
} }
#endregion #endregion
#region GetCustomer #region GetCustomer
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetCustomer_NullSubscriber_ThrowsArgumentNullException( public async Task GetCustomer_NullSubscriber_ThrowsArgumentNullException(
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
@ -256,9 +258,11 @@ public class SubscriberServiceTests
Assert.Equivalent(customer, gotCustomer); Assert.Equivalent(customer, gotCustomer);
} }
#endregion #endregion
#region GetCustomerOrThrow #region GetCustomerOrThrow
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetCustomerOrThrow_NullSubscriber_ThrowsArgumentNullException( public async Task GetCustomerOrThrow_NullSubscriber_ThrowsArgumentNullException(
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
@ -266,17 +270,17 @@ public class SubscriberServiceTests
async () => await sutProvider.Sut.GetCustomerOrThrow(null)); async () => await sutProvider.Sut.GetCustomerOrThrow(null));
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetCustomerOrThrow_NoGatewayCustomerId_ContactSupport( public async Task GetCustomerOrThrow_NoGatewayCustomerId_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
organization.GatewayCustomerId = null; organization.GatewayCustomerId = null;
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization));
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetCustomerOrThrow_NoCustomer_ContactSupport( public async Task GetCustomerOrThrow_NoCustomer_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -284,11 +288,11 @@ public class SubscriberServiceTests
.CustomerGetAsync(organization.GatewayCustomerId) .CustomerGetAsync(organization.GatewayCustomerId)
.ReturnsNull(); .ReturnsNull();
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization));
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetCustomerOrThrow_StripeException_ContactSupport( public async Task GetCustomerOrThrow_StripeException_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -298,10 +302,10 @@ public class SubscriberServiceTests
.CustomerGetAsync(organization.GatewayCustomerId) .CustomerGetAsync(organization.GatewayCustomerId)
.ThrowsAsync(stripeException); .ThrowsAsync(stripeException);
await ThrowsContactSupportAsync( await ThrowsBillingExceptionAsync(
async () => await sutProvider.Sut.GetCustomerOrThrow(organization), async () => await sutProvider.Sut.GetCustomerOrThrow(organization),
"An error occurred while trying to retrieve a Stripe Customer", message: "An error occurred while trying to retrieve a Stripe customer",
stripeException); innerException: stripeException);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
@ -319,108 +323,6 @@ public class SubscriberServiceTests
Assert.Equivalent(customer, gotCustomer); Assert.Equivalent(customer, gotCustomer);
} }
#endregion
#region GetInvoices
[Theory, BitAutoData]
public async Task GetInvoices_NullSubscriber_ThrowsArgumentNullException(
SutProvider<SubscriberService> sutProvider)
=> await Assert.ThrowsAsync<ArgumentNullException>(
async () => await sutProvider.Sut.GetInvoices(null));
[Theory, BitAutoData]
public async Task GetCustomer_NoGatewayCustomerId_ReturnsEmptyList(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
organization.GatewayCustomerId = null;
var invoices = await sutProvider.Sut.GetInvoices(organization);
Assert.Empty(invoices);
}
[Theory, BitAutoData]
public async Task GetInvoices_StripeException_ReturnsEmptyList(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
sutProvider.GetDependency<IStripeAdapter>()
.InvoiceListAsync(Arg.Any<StripeInvoiceListOptions>())
.ThrowsAsync<StripeException>();
var invoices = await sutProvider.Sut.GetInvoices(organization);
Assert.Empty(invoices);
}
[Theory, BitAutoData]
public async Task GetInvoices_NullOptions_Succeeds(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
var invoices = new List<Invoice>
{
new ()
{
Created = new DateTime(2024, 6, 1),
Number = "2",
Status = "open",
Total = 100000,
HostedInvoiceUrl = "https://example.com/invoice/2",
InvoicePdf = "https://example.com/invoice/2/pdf"
},
new ()
{
Created = new DateTime(2024, 5, 1),
Number = "1",
Status = "paid",
Total = 100000,
HostedInvoiceUrl = "https://example.com/invoice/1",
InvoicePdf = "https://example.com/invoice/1/pdf"
}
};
sutProvider.GetDependency<IStripeAdapter>()
.InvoiceListAsync(Arg.Is<StripeInvoiceListOptions>(options => options.Customer == organization.GatewayCustomerId))
.Returns(invoices);
var gotInvoices = await sutProvider.Sut.GetInvoices(organization);
Assert.Equivalent(invoices, gotInvoices);
}
[Theory, BitAutoData]
public async Task GetInvoices_ProvidedOptions_Succeeds(
Organization organization,
SutProvider<SubscriberService> sutProvider)
{
var invoices = new List<Invoice>
{
new ()
{
Created = new DateTime(2024, 5, 1),
Number = "1",
Status = "paid",
Total = 100000,
}
};
sutProvider.GetDependency<IStripeAdapter>()
.InvoiceListAsync(Arg.Is<StripeInvoiceListOptions>(
options =>
options.Customer == organization.GatewayCustomerId &&
options.Status == "paid"))
.Returns(invoices);
var gotInvoices = await sutProvider.Sut.GetInvoices(organization, new StripeInvoiceListOptions
{
Status = "paid"
});
Assert.Equivalent(invoices, gotInvoices);
}
#endregion #endregion
@ -795,17 +697,17 @@ public class SubscriberServiceTests
async () => await sutProvider.Sut.GetSubscriptionOrThrow(null)); async () => await sutProvider.Sut.GetSubscriptionOrThrow(null));
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_NoGatewaySubscriptionId_ContactSupport( public async Task GetSubscriptionOrThrow_NoGatewaySubscriptionId_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
organization.GatewaySubscriptionId = null; organization.GatewaySubscriptionId = null;
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization));
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_NoSubscription_ContactSupport( public async Task GetSubscriptionOrThrow_NoSubscription_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -813,11 +715,11 @@ public class SubscriberServiceTests
.SubscriptionGetAsync(organization.GatewaySubscriptionId) .SubscriptionGetAsync(organization.GatewaySubscriptionId)
.ReturnsNull(); .ReturnsNull();
await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization));
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetSubscriptionOrThrow_StripeException_ContactSupport( public async Task GetSubscriptionOrThrow_StripeException_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -827,10 +729,10 @@ public class SubscriberServiceTests
.SubscriptionGetAsync(organization.GatewaySubscriptionId) .SubscriptionGetAsync(organization.GatewaySubscriptionId)
.ThrowsAsync(stripeException); .ThrowsAsync(stripeException);
await ThrowsContactSupportAsync( await ThrowsBillingExceptionAsync(
async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization), async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization),
"An error occurred while trying to retrieve a Stripe Subscription", message: "An error occurred while trying to retrieve a Stripe subscription",
stripeException); innerException: stripeException);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
@ -911,12 +813,12 @@ public class SubscriberServiceTests
#region RemovePaymentMethod #region RemovePaymentMethod
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task RemovePaymentMethod_NullSubscriber_ArgumentNullException( public async Task RemovePaymentMethod_NullSubscriber_ThrowsArgumentNullException(
SutProvider<SubscriberService> sutProvider) => SutProvider<SubscriberService> sutProvider) =>
await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.RemovePaymentMethod(null)); await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.RemovePaymentMethod(null));
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task RemovePaymentMethod_Braintree_NoCustomer_ContactSupport( public async Task RemovePaymentMethod_Braintree_NoCustomer_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -940,7 +842,7 @@ public class SubscriberServiceTests
braintreeGateway.Customer.Returns(customerGateway); braintreeGateway.Customer.Returns(customerGateway);
await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.RemovePaymentMethod(organization));
await customerGateway.Received(1).FindAsync(braintreeCustomerId); await customerGateway.Received(1).FindAsync(braintreeCustomerId);
@ -987,7 +889,7 @@ public class SubscriberServiceTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task RemovePaymentMethod_Braintree_CustomerUpdateFails_ContactSupport( public async Task RemovePaymentMethod_Braintree_CustomerUpdateFails_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -1028,7 +930,7 @@ public class SubscriberServiceTests
Arg.Is<CustomerRequest>(request => request.DefaultPaymentMethodToken == null)) Arg.Is<CustomerRequest>(request => request.DefaultPaymentMethodToken == null))
.Returns(updateBraintreeCustomerResult); .Returns(updateBraintreeCustomerResult);
await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.RemovePaymentMethod(organization));
await customerGateway.Received(1).FindAsync(braintreeCustomerId); await customerGateway.Received(1).FindAsync(braintreeCustomerId);
@ -1042,7 +944,7 @@ public class SubscriberServiceTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task RemovePaymentMethod_Braintree_PaymentMethodDeleteFails_RollBack_ContactSupport( public async Task RemovePaymentMethod_Braintree_PaymentMethodDeleteFails_RollBack_ThrowsBillingException(
Organization organization, Organization organization,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -1086,7 +988,7 @@ public class SubscriberServiceTests
paymentMethodGateway.DeleteAsync(paymentMethod.Token).Returns(deleteBraintreePaymentMethodResult); paymentMethodGateway.DeleteAsync(paymentMethod.Token).Returns(deleteBraintreePaymentMethodResult);
await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.RemovePaymentMethod(organization));
await customerGateway.Received(1).FindAsync(braintreeCustomerId); await customerGateway.Received(1).FindAsync(braintreeCustomerId);
@ -1206,42 +1108,42 @@ public class SubscriberServiceTests
#region UpdatePaymentMethod #region UpdatePaymentMethod
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdatePaymentMethod_NullSubscriber_ArgumentNullException( public async Task UpdatePaymentMethod_NullSubscriber_ThrowsArgumentNullException(
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
=> await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.UpdatePaymentMethod(null, null)); => await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.UpdatePaymentMethod(null, null));
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdatePaymentMethod_NullTokenizedPaymentMethod_ArgumentNullException( public async Task UpdatePaymentMethod_NullTokenizedPaymentMethod_ThrowsArgumentNullException(
Provider provider, Provider provider,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
=> await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.UpdatePaymentMethod(provider, null)); => await Assert.ThrowsAsync<ArgumentNullException>(() => sutProvider.Sut.UpdatePaymentMethod(provider, null));
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdatePaymentMethod_NoToken_ContactSupport( public async Task UpdatePaymentMethod_NoToken_ThrowsBillingException(
Provider provider, Provider provider,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(provider.GatewayCustomerId) sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(provider.GatewayCustomerId)
.Returns(new Customer()); .Returns(new Customer());
await ThrowsContactSupportAsync(() => await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.Card, null))); sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.Card, null)));
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdatePaymentMethod_UnsupportedPaymentMethod_ContactSupport( public async Task UpdatePaymentMethod_UnsupportedPaymentMethod_ThrowsBillingException(
Provider provider, Provider provider,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(provider.GatewayCustomerId) sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(provider.GatewayCustomerId)
.Returns(new Customer()); .Returns(new Customer());
await ThrowsContactSupportAsync(() => await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.BitPay, "TOKEN"))); sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.BitPay, "TOKEN")));
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdatePaymentMethod_BankAccount_IncorrectNumberOfSetupIntentsForToken_ContactSupport( public async Task UpdatePaymentMethod_BankAccount_IncorrectNumberOfSetupIntentsForToken_ThrowsBillingException(
Provider provider, Provider provider,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -1253,7 +1155,7 @@ public class SubscriberServiceTests
stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options => options.PaymentMethod == "TOKEN")) stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options => options.PaymentMethod == "TOKEN"))
.Returns([new SetupIntent(), new SetupIntent()]); .Returns([new SetupIntent(), new SetupIntent()]);
await ThrowsContactSupportAsync(() => await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.BankAccount, "TOKEN"))); sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.BankAccount, "TOKEN")));
} }
@ -1348,7 +1250,7 @@ public class SubscriberServiceTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdatePaymentMethod_Braintree_NullCustomer_ContactSupport( public async Task UpdatePaymentMethod_Braintree_NullCustomer_ThrowsBillingException(
Provider provider, Provider provider,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -1368,13 +1270,13 @@ public class SubscriberServiceTests
customerGateway.FindAsync(braintreeCustomerId).ReturnsNull(); customerGateway.FindAsync(braintreeCustomerId).ReturnsNull();
await ThrowsContactSupportAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN")));
await paymentMethodGateway.DidNotReceiveWithAnyArgs().CreateAsync(Arg.Any<PaymentMethodRequest>()); await paymentMethodGateway.DidNotReceiveWithAnyArgs().CreateAsync(Arg.Any<PaymentMethodRequest>());
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdatePaymentMethod_Braintree_ReplacePaymentMethod_CreatePaymentMethodFails_ContactSupport( public async Task UpdatePaymentMethod_Braintree_ReplacePaymentMethod_CreatePaymentMethodFails_ThrowsBillingException(
Provider provider, Provider provider,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -1406,13 +1308,13 @@ public class SubscriberServiceTests
options => options.CustomerId == braintreeCustomerId && options.PaymentMethodNonce == "TOKEN")) options => options.CustomerId == braintreeCustomerId && options.PaymentMethodNonce == "TOKEN"))
.Returns(createPaymentMethodResult); .Returns(createPaymentMethodResult);
await ThrowsContactSupportAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN")));
await customerGateway.DidNotReceiveWithAnyArgs().UpdateAsync(Arg.Any<string>(), Arg.Any<CustomerRequest>()); await customerGateway.DidNotReceiveWithAnyArgs().UpdateAsync(Arg.Any<string>(), Arg.Any<CustomerRequest>());
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdatePaymentMethod_Braintree_ReplacePaymentMethod_UpdateCustomerFails_DeletePaymentMethod_ContactSupport( public async Task UpdatePaymentMethod_Braintree_ReplacePaymentMethod_UpdateCustomerFails_DeletePaymentMethod_ThrowsBillingException(
Provider provider, Provider provider,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -1458,7 +1360,7 @@ public class SubscriberServiceTests
options.DefaultPaymentMethodToken == createPaymentMethodResult.Target.Token)) options.DefaultPaymentMethodToken == createPaymentMethodResult.Target.Token))
.Returns(updateCustomerResult); .Returns(updateCustomerResult);
await ThrowsContactSupportAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN")));
await paymentMethodGateway.Received(1).DeleteAsync(createPaymentMethodResult.Target.Token); await paymentMethodGateway.Received(1).DeleteAsync(createPaymentMethodResult.Target.Token);
} }
@ -1531,7 +1433,7 @@ public class SubscriberServiceTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task UpdatePaymentMethod_Braintree_CreateCustomer_CustomerUpdateFails_ContactSupport( public async Task UpdatePaymentMethod_Braintree_CreateCustomer_CustomerUpdateFails_ThrowsBillingException(
Provider provider, Provider provider,
SutProvider<SubscriberService> sutProvider) SutProvider<SubscriberService> sutProvider)
{ {
@ -1564,7 +1466,7 @@ public class SubscriberServiceTests
options.PaymentMethodNonce == "TOKEN")) options.PaymentMethodNonce == "TOKEN"))
.Returns(createCustomerResult); .Returns(createCustomerResult);
await ThrowsContactSupportAsync(() => await ThrowsBillingExceptionAsync(() =>
sutProvider.Sut.UpdatePaymentMethod(provider, sutProvider.Sut.UpdatePaymentMethod(provider,
new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN")));
@ -1648,7 +1550,7 @@ public class SubscriberServiceTests
stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is<CustomerGetOptions>( stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is<CustomerGetOptions>(
options => options.Expand.Contains("tax_ids"))).Returns(customer); options => options.Expand.Contains("tax_ids"))).Returns(customer);
var taxInformation = new TaxInformationDTO( var taxInformation = new TaxInformation(
"US", "US",
"12345", "12345",
"123456789", "123456789",
@ -1685,9 +1587,9 @@ public class SubscriberServiceTests
() => sutProvider.Sut.VerifyBankAccount(null, (0, 0))); () => sutProvider.Sut.VerifyBankAccount(null, (0, 0)));
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task VerifyBankAccount_NoSetupIntentId_ContactSupport( public async Task VerifyBankAccount_NoSetupIntentId_ThrowsBillingException(
Provider provider, Provider provider,
SutProvider<SubscriberService> sutProvider) => await ThrowsContactSupportAsync(() => sutProvider.Sut.VerifyBankAccount(provider, (1, 1))); SutProvider<SubscriberService> sutProvider) => await ThrowsBillingExceptionAsync(() => sutProvider.Sut.VerifyBankAccount(provider, (1, 1)));
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task VerifyBankAccount_MakesCorrectInvocations( public async Task VerifyBankAccount_MakesCorrectInvocations(

View File

@ -1,23 +1,22 @@
using Bit.Core.Billing; using Bit.Core.Billing;
using Xunit; using Xunit;
using static Bit.Core.Billing.Utilities;
namespace Bit.Core.Test.Billing; namespace Bit.Core.Test.Billing;
public static class Utilities public static class Utilities
{ {
public static async Task ThrowsContactSupportAsync( public static async Task ThrowsBillingExceptionAsync(
Func<Task> function, Func<Task> function,
string internalMessage = null, string response = null,
string message = null,
Exception innerException = null) Exception innerException = null)
{ {
var contactSupport = ContactSupport(internalMessage, innerException); var expected = new BillingException(response, message, innerException);
var exception = await Assert.ThrowsAsync<BillingException>(function); var actual = await Assert.ThrowsAsync<BillingException>(function);
Assert.Equal(contactSupport.ClientFriendlyMessage, exception.ClientFriendlyMessage); Assert.Equal(expected.Response, actual.Response);
Assert.Equal(contactSupport.Message, exception.Message); Assert.Equal(expected.Message, actual.Message);
Assert.Equal(contactSupport.InnerException, exception.InnerException); Assert.Equal(expected.InnerException, actual.InnerException);
} }
} }