diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs index 2a4ba3a1db..822f9635eb 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs @@ -66,7 +66,7 @@ public class OrganizationsController : Controller private readonly IAddSecretsManagerSubscriptionCommand _addSecretsManagerSubscriptionCommand; private readonly IPushNotificationService _pushNotificationService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly IGetSubscriptionQuery _getSubscriptionQuery; + private readonly ISubscriberQueries _subscriberQueries; private readonly IReferenceEventService _referenceEventService; private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand; @@ -93,7 +93,7 @@ public class OrganizationsController : Controller IAddSecretsManagerSubscriptionCommand addSecretsManagerSubscriptionCommand, IPushNotificationService pushNotificationService, ICancelSubscriptionCommand cancelSubscriptionCommand, - IGetSubscriptionQuery getSubscriptionQuery, + ISubscriberQueries subscriberQueries, IReferenceEventService referenceEventService, IOrganizationEnableCollectionEnhancementsCommand organizationEnableCollectionEnhancementsCommand) { @@ -119,7 +119,7 @@ public class OrganizationsController : Controller _addSecretsManagerSubscriptionCommand = addSecretsManagerSubscriptionCommand; _pushNotificationService = pushNotificationService; _cancelSubscriptionCommand = cancelSubscriptionCommand; - _getSubscriptionQuery = getSubscriptionQuery; + _subscriberQueries = subscriberQueries; _referenceEventService = referenceEventService; _organizationEnableCollectionEnhancementsCommand = organizationEnableCollectionEnhancementsCommand; } @@ -479,7 +479,7 @@ public class OrganizationsController : Controller throw new NotFoundException(); } - var subscription = await _getSubscriptionQuery.GetSubscription(organization); + var subscription = await _subscriberQueries.GetSubscriptionOrThrow(organization); await _cancelSubscriptionCommand.CancelSubscription(subscription, new OffboardingSurveyResponse diff --git a/src/Api/Auth/Controllers/AccountsController.cs b/src/Api/Auth/Controllers/AccountsController.cs index 29ede684be..5f1910fb28 100644 --- a/src/Api/Auth/Controllers/AccountsController.cs +++ b/src/Api/Auth/Controllers/AccountsController.cs @@ -69,7 +69,7 @@ public class AccountsController : Controller private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IFeatureService _featureService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly IGetSubscriptionQuery _getSubscriptionQuery; + private readonly ISubscriberQueries _subscriberQueries; private readonly IReferenceEventService _referenceEventService; private readonly ICurrentContext _currentContext; @@ -104,7 +104,7 @@ public class AccountsController : Controller IRotateUserKeyCommand rotateUserKeyCommand, IFeatureService featureService, ICancelSubscriptionCommand cancelSubscriptionCommand, - IGetSubscriptionQuery getSubscriptionQuery, + ISubscriberQueries subscriberQueries, IReferenceEventService referenceEventService, ICurrentContext currentContext, IRotationValidator, IEnumerable> cipherValidator, @@ -133,7 +133,7 @@ public class AccountsController : Controller _rotateUserKeyCommand = rotateUserKeyCommand; _featureService = featureService; _cancelSubscriptionCommand = cancelSubscriptionCommand; - _getSubscriptionQuery = getSubscriptionQuery; + _subscriberQueries = subscriberQueries; _referenceEventService = referenceEventService; _currentContext = currentContext; _cipherValidator = cipherValidator; @@ -831,7 +831,7 @@ public class AccountsController : Controller throw new UnauthorizedAccessException(); } - var subscription = await _getSubscriptionQuery.GetSubscription(user); + var subscription = await _subscriberQueries.GetSubscriptionOrThrow(user); await _cancelSubscriptionCommand.CancelSubscription(subscription, new OffboardingSurveyResponse diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs new file mode 100644 index 0000000000..583a5937e4 --- /dev/null +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -0,0 +1,44 @@ +using Bit.Api.Billing.Models; +using Bit.Core; +using Bit.Core.Billing.Queries; +using Bit.Core.Context; +using Bit.Core.Services; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.Billing.Controllers; + +[Route("providers/{providerId:guid}/billing")] +[Authorize("Application")] +public class ProviderBillingController( + ICurrentContext currentContext, + IFeatureService featureService, + IProviderBillingQueries providerBillingQueries) : Controller +{ + [HttpGet("subscription")] + public async Task GetSubscriptionAsync([FromRoute] Guid providerId) + { + if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { + return TypedResults.NotFound(); + } + + if (!currentContext.ProviderProviderAdmin(providerId)) + { + return TypedResults.Unauthorized(); + } + + var subscriptionData = await providerBillingQueries.GetSubscriptionData(providerId); + + if (subscriptionData == null) + { + return TypedResults.NotFound(); + } + + var (providerPlans, subscription) = subscriptionData; + + var providerSubscriptionDTO = ProviderSubscriptionDTO.From(providerPlans, subscription); + + return TypedResults.Ok(providerSubscriptionDTO); + } +} diff --git a/src/Api/Billing/Models/ProviderSubscriptionDTO.cs b/src/Api/Billing/Models/ProviderSubscriptionDTO.cs new file mode 100644 index 0000000000..0e8b8bfb1c --- /dev/null +++ b/src/Api/Billing/Models/ProviderSubscriptionDTO.cs @@ -0,0 +1,47 @@ +using Bit.Core.Billing.Models; +using Bit.Core.Utilities; +using Stripe; + +namespace Bit.Api.Billing.Models; + +public record ProviderSubscriptionDTO( + string Status, + DateTime CurrentPeriodEndDate, + decimal? DiscountPercentage, + IEnumerable Plans) +{ + private const string _annualCadence = "Annual"; + private const string _monthlyCadence = "Monthly"; + + public static ProviderSubscriptionDTO From( + IEnumerable providerPlans, + Subscription subscription) + { + var providerPlansDTO = providerPlans + .Select(providerPlan => + { + var plan = StaticStore.GetPlan(providerPlan.PlanType); + var cost = (providerPlan.SeatMinimum + providerPlan.PurchasedSeats) * plan.PasswordManager.SeatPrice; + var cadence = plan.IsAnnual ? _annualCadence : _monthlyCadence; + return new ProviderPlanDTO( + plan.Name, + providerPlan.SeatMinimum, + providerPlan.PurchasedSeats, + cost, + cadence); + }); + + return new ProviderSubscriptionDTO( + subscription.Status, + subscription.CurrentPeriodEnd, + subscription.Customer?.Discount?.Coupon?.PercentOff, + providerPlansDTO); + } +} + +public record ProviderPlanDTO( + string PlanName, + int SeatMinimum, + int PurchasedSeats, + decimal Cost, + string Cadence); diff --git a/src/Core/AdminConsole/Entities/Provider/Provider.cs b/src/Core/AdminConsole/Entities/Provider/Provider.cs index ee2b35ed90..e5b794e6b1 100644 --- a/src/Core/AdminConsole/Entities/Provider/Provider.cs +++ b/src/Core/AdminConsole/Entities/Provider/Provider.cs @@ -6,7 +6,7 @@ using Bit.Core.Utilities; namespace Bit.Core.AdminConsole.Entities.Provider; -public class Provider : ITableObject +public class Provider : ITableObject, ISubscriber { public Guid Id { get; set; } /// @@ -34,6 +34,26 @@ public class Provider : ITableObject public string GatewayCustomerId { get; set; } public string GatewaySubscriptionId { get; set; } + public string BillingEmailAddress() => BillingEmail?.ToLowerInvariant().Trim(); + + public string BillingName() => DisplayBusinessName(); + + public string SubscriberName() => DisplayName(); + + public string BraintreeCustomerIdPrefix() => "p"; + + public string BraintreeIdField() => "provider_id"; + + public string BraintreeCloudRegionField() => "region"; + + public bool IsOrganization() => false; + + public bool IsUser() => false; + + public string SubscriberType() => "Provider"; + + public bool IsExpired() => false; + public void SetNewId() { if (Id == default) diff --git a/src/Core/Billing/BillingException.cs b/src/Core/Billing/BillingException.cs new file mode 100644 index 0000000000..a6944b3ed6 --- /dev/null +++ b/src/Core/Billing/BillingException.cs @@ -0,0 +1,9 @@ +namespace Bit.Core.Billing; + +public class BillingException( + string clientFriendlyMessage, + string internalMessage = null, + Exception innerException = null) : Exception(internalMessage, innerException) +{ + public string ClientFriendlyMessage { get; set; } = clientFriendlyMessage; +} diff --git a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs b/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs index b23880e650..88708d3d2e 100644 --- a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs +++ b/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs @@ -1,7 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Models; using Bit.Core.Entities; -using Bit.Core.Exceptions; using Stripe; namespace Bit.Core.Billing.Commands; @@ -17,7 +16,6 @@ public interface ICancelSubscriptionCommand /// The or with the subscription to cancel. /// An DTO containing user-provided feedback on why they are cancelling the subscription. /// A flag indicating whether to cancel the subscription immediately or at the end of the subscription period. - /// Thrown when the provided subscription is already in an inactive state. Task CancelSubscription( Subscription subscription, OffboardingSurveyResponse offboardingSurveyResponse, diff --git a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs index 62bf0d0926..e2be6f45eb 100644 --- a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs +++ b/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs @@ -4,5 +4,12 @@ namespace Bit.Core.Billing.Commands; public interface IRemovePaymentMethodCommand { + /// + /// Attempts to remove an Organization's saved payment method. If the Stripe representing the + /// contains a valid "btCustomerId" key in its property, + /// this command will attempt to remove the Braintree . Otherwise, it will attempt to remove the + /// Stripe . + /// + /// The organization to remove the saved payment method for. Task RemovePaymentMethod(Organization organization); } diff --git a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs index c5dbb6d927..be8479ea99 100644 --- a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs +++ b/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs @@ -1,55 +1,41 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Enums; -using Bit.Core.Exceptions; using Bit.Core.Services; using Braintree; using Microsoft.Extensions.Logging; +using static Bit.Core.Billing.Utilities; + namespace Bit.Core.Billing.Commands.Implementations; -public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand +public class RemovePaymentMethodCommand( + IBraintreeGateway braintreeGateway, + ILogger logger, + IStripeAdapter stripeAdapter) + : IRemovePaymentMethodCommand { - private readonly IBraintreeGateway _braintreeGateway; - private readonly ILogger _logger; - private readonly IStripeAdapter _stripeAdapter; - - public RemovePaymentMethodCommand( - IBraintreeGateway braintreeGateway, - ILogger logger, - IStripeAdapter stripeAdapter) - { - _braintreeGateway = braintreeGateway; - _logger = logger; - _stripeAdapter = stripeAdapter; - } - public async Task RemovePaymentMethod(Organization organization) { - const string braintreeCustomerIdKey = "btCustomerId"; - - if (organization == null) - { - throw new ArgumentNullException(nameof(organization)); - } + ArgumentNullException.ThrowIfNull(organization); if (organization.Gateway is not GatewayType.Stripe || string.IsNullOrEmpty(organization.GatewayCustomerId)) { throw ContactSupport(); } - var stripeCustomer = await _stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, new Stripe.CustomerGetOptions + var stripeCustomer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, new Stripe.CustomerGetOptions { - Expand = new List { "invoice_settings.default_payment_method", "sources" } + Expand = ["invoice_settings.default_payment_method", "sources"] }); if (stripeCustomer == null) { - _logger.LogError("Could not find Stripe customer ({ID}) when removing payment method", organization.GatewayCustomerId); + logger.LogError("Could not find Stripe customer ({ID}) when removing payment method", organization.GatewayCustomerId); throw ContactSupport(); } - if (stripeCustomer.Metadata?.TryGetValue(braintreeCustomerIdKey, out var braintreeCustomerId) ?? false) + if (stripeCustomer.Metadata?.TryGetValue(BraintreeCustomerIdKey, out var braintreeCustomerId) ?? false) { await RemoveBraintreePaymentMethodAsync(braintreeCustomerId); } @@ -61,11 +47,11 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand private async Task RemoveBraintreePaymentMethodAsync(string braintreeCustomerId) { - var customer = await _braintreeGateway.Customer.FindAsync(braintreeCustomerId); + var customer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); if (customer == null) { - _logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); + logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); throw ContactSupport(); } @@ -74,27 +60,27 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand { var existingDefaultPaymentMethod = customer.DefaultPaymentMethod; - var updateCustomerResult = await _braintreeGateway.Customer.UpdateAsync( + var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync( braintreeCustomerId, new CustomerRequest { DefaultPaymentMethodToken = null }); if (!updateCustomerResult.IsSuccess()) { - _logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", + logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", braintreeCustomerId, updateCustomerResult.Message); throw ContactSupport(); } - var deletePaymentMethodResult = await _braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); + var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); if (!deletePaymentMethodResult.IsSuccess()) { - await _braintreeGateway.Customer.UpdateAsync( + await braintreeGateway.Customer.UpdateAsync( braintreeCustomerId, new CustomerRequest { DefaultPaymentMethodToken = existingDefaultPaymentMethod.Token }); - _logger.LogError( + logger.LogError( "Failed to delete Braintree payment method for Customer ({ID}), re-linked payment method. Message: {Message}", braintreeCustomerId, deletePaymentMethodResult.Message); @@ -103,7 +89,7 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand } else { - _logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId); + logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId); } } @@ -116,25 +102,23 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand switch (source) { case Stripe.BankAccount: - await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); + await stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); break; case Stripe.Card: - await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id); + await stripeAdapter.CardDeleteAsync(customer.Id, source.Id); break; } } } - var paymentMethods = _stripeAdapter.PaymentMethodListAutoPagingAsync(new Stripe.PaymentMethodListOptions + var paymentMethods = stripeAdapter.PaymentMethodListAutoPagingAsync(new Stripe.PaymentMethodListOptions { Customer = customer.Id }); await foreach (var paymentMethod in paymentMethods) { - await _stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id, new Stripe.PaymentMethodDetachOptions()); + await stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id, new Stripe.PaymentMethodDetachOptions()); } } - - private static GatewayException ContactSupport() => new("Could not remove your payment method. Please contact support for assistance."); } diff --git a/src/Core/Billing/Entities/ProviderPlan.cs b/src/Core/Billing/Entities/ProviderPlan.cs index 325dbbb156..2f15a539e1 100644 --- a/src/Core/Billing/Entities/ProviderPlan.cs +++ b/src/Core/Billing/Entities/ProviderPlan.cs @@ -11,7 +11,6 @@ public class ProviderPlan : ITableObject public PlanType PlanType { get; set; } public int? SeatMinimum { get; set; } public int? PurchasedSeats { get; set; } - public int? AllocatedSeats { get; set; } public void SetNewId() { @@ -20,4 +19,6 @@ public class ProviderPlan : ITableObject Id = CoreHelpers.GenerateComb(); } } + + public bool Configured => SeatMinimum.HasValue && PurchasedSeats.HasValue; } diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 113fa4d5b7..751bfdb671 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -17,6 +17,7 @@ public static class ServiceCollectionExtensions public static void AddBillingQueries(this IServiceCollection services) { - services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); } } diff --git a/src/Core/Billing/Models/ConfiguredProviderPlan.cs b/src/Core/Billing/Models/ConfiguredProviderPlan.cs new file mode 100644 index 0000000000..d5d53b36fa --- /dev/null +++ b/src/Core/Billing/Models/ConfiguredProviderPlan.cs @@ -0,0 +1,22 @@ +using Bit.Core.Billing.Entities; +using Bit.Core.Enums; + +namespace Bit.Core.Billing.Models; + +public record ConfiguredProviderPlan( + Guid Id, + Guid ProviderId, + PlanType PlanType, + int SeatMinimum, + int PurchasedSeats) +{ + public static ConfiguredProviderPlan From(ProviderPlan providerPlan) => + providerPlan.Configured + ? new ConfiguredProviderPlan( + providerPlan.Id, + providerPlan.ProviderId, + providerPlan.PlanType, + providerPlan.SeatMinimum.GetValueOrDefault(0), + providerPlan.PurchasedSeats.GetValueOrDefault(0)) + : null; +} diff --git a/src/Core/Billing/Models/ProviderSubscriptionData.cs b/src/Core/Billing/Models/ProviderSubscriptionData.cs new file mode 100644 index 0000000000..27da6cd226 --- /dev/null +++ b/src/Core/Billing/Models/ProviderSubscriptionData.cs @@ -0,0 +1,7 @@ +using Stripe; + +namespace Bit.Core.Billing.Models; + +public record ProviderSubscriptionData( + List ProviderPlans, + Subscription Subscription); diff --git a/src/Core/Billing/Queries/IGetSubscriptionQuery.cs b/src/Core/Billing/Queries/IGetSubscriptionQuery.cs deleted file mode 100644 index 9ba2a85ed5..0000000000 --- a/src/Core/Billing/Queries/IGetSubscriptionQuery.cs +++ /dev/null @@ -1,18 +0,0 @@ -using Bit.Core.Entities; -using Bit.Core.Exceptions; -using Stripe; - -namespace Bit.Core.Billing.Queries; - -public interface IGetSubscriptionQuery -{ - /// - /// Retrieves a Stripe using the 's property. - /// - /// The organization or user to retrieve the subscription for. - /// A Stripe . - /// Thrown when the is . - /// Thrown when the subscriber's is or empty. - /// Thrown when the returned from Stripe's API is null. - Task GetSubscription(ISubscriber subscriber); -} diff --git a/src/Core/Billing/Queries/IProviderBillingQueries.cs b/src/Core/Billing/Queries/IProviderBillingQueries.cs new file mode 100644 index 0000000000..1edfddaf56 --- /dev/null +++ b/src/Core/Billing/Queries/IProviderBillingQueries.cs @@ -0,0 +1,14 @@ +using Bit.Core.Billing.Models; + +namespace Bit.Core.Billing.Queries; + +public interface IProviderBillingQueries +{ + /// + /// Retrieves a provider's billing subscription data. + /// + /// The ID of the provider to retrieve subscription data for. + /// A object containing the provider's Stripe and their s. + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetSubscriptionData(Guid providerId); +} diff --git a/src/Core/Billing/Queries/ISubscriberQueries.cs b/src/Core/Billing/Queries/ISubscriberQueries.cs new file mode 100644 index 0000000000..ea6c0d985e --- /dev/null +++ b/src/Core/Billing/Queries/ISubscriberQueries.cs @@ -0,0 +1,30 @@ +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Stripe; + +namespace Bit.Core.Billing.Queries; + +public interface ISubscriberQueries +{ + /// + /// Retrieves a Stripe using the 's property. + /// + /// The organization, provider or user to retrieve the subscription for. + /// Optional parameters that can be passed to Stripe to expand or modify the . + /// A Stripe . + /// Thrown when the is . + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetSubscription( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null); + + /// + /// Retrieves a Stripe using the 's property. + /// + /// The organization or user to retrieve the subscription for. + /// A Stripe . + /// Thrown when the is . + /// Thrown when the subscriber's is or empty. + /// Thrown when the returned from Stripe's API is null. + Task GetSubscriptionOrThrow(ISubscriber subscriber); +} diff --git a/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs b/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs deleted file mode 100644 index c3b0a29552..0000000000 --- a/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs +++ /dev/null @@ -1,36 +0,0 @@ -using Bit.Core.Entities; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; - -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Queries.Implementations; - -public class GetSubscriptionQuery( - ILogger logger, - IStripeAdapter stripeAdapter) : IGetSubscriptionQuery -{ - public async Task GetSubscription(ISubscriber subscriber) - { - ArgumentNullException.ThrowIfNull(subscriber); - - if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) - { - logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); - - throw ContactSupport(); - } - - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - - if (subscription != null) - { - return subscription; - } - - logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); - - throw ContactSupport(); - } -} diff --git a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs new file mode 100644 index 0000000000..c921e82969 --- /dev/null +++ b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs @@ -0,0 +1,49 @@ +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Repositories; +using Microsoft.Extensions.Logging; +using Stripe; + +namespace Bit.Core.Billing.Queries.Implementations; + +public class ProviderBillingQueries( + ILogger logger, + IProviderPlanRepository providerPlanRepository, + IProviderRepository providerRepository, + ISubscriberQueries subscriberQueries) : IProviderBillingQueries +{ + public async Task GetSubscriptionData(Guid providerId) + { + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + logger.LogError( + "Could not find provider ({ID}) when retrieving subscription data.", + providerId); + + return null; + } + + var subscription = await subscriberQueries.GetSubscription(provider, new SubscriptionGetOptions + { + Expand = ["customer"] + }); + + if (subscription == null) + { + return null; + } + + var providerPlans = await providerPlanRepository.GetByProviderId(providerId); + + var configuredProviderPlans = providerPlans + .Where(providerPlan => providerPlan.Configured) + .Select(ConfiguredProviderPlan.From) + .ToList(); + + return new ProviderSubscriptionData( + configuredProviderPlans, + subscription); + } +} diff --git a/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs new file mode 100644 index 0000000000..a160a87595 --- /dev/null +++ b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs @@ -0,0 +1,61 @@ +using Bit.Core.Entities; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using Stripe; + +using static Bit.Core.Billing.Utilities; + +namespace Bit.Core.Billing.Queries.Implementations; + +public class SubscriberQueries( + ILogger logger, + IStripeAdapter stripeAdapter) : ISubscriberQueries +{ + public async Task GetSubscription( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); + + return null; + } + + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + + if (subscription != null) + { + return subscription; + } + + logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); + + return null; + } + + public async Task GetSubscriptionOrThrow(ISubscriber subscriber) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); + + throw ContactSupport(); + } + + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + + if (subscription != null) + { + return subscription; + } + + logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); + + throw ContactSupport(); + } +} diff --git a/src/Core/Billing/Repositories/IProviderPlanRepository.cs b/src/Core/Billing/Repositories/IProviderPlanRepository.cs index ccfc6ee683..eccbad82bb 100644 --- a/src/Core/Billing/Repositories/IProviderPlanRepository.cs +++ b/src/Core/Billing/Repositories/IProviderPlanRepository.cs @@ -5,5 +5,5 @@ namespace Bit.Core.Billing.Repositories; public interface IProviderPlanRepository : IRepository { - Task GetByProviderId(Guid providerId); + Task> GetByProviderId(Guid providerId); } diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs index 54ace07a70..2b06f1ea6c 100644 --- a/src/Core/Billing/Utilities.cs +++ b/src/Core/Billing/Utilities.cs @@ -1,8 +1,11 @@ -using Bit.Core.Exceptions; - -namespace Bit.Core.Billing; +namespace Bit.Core.Billing; public static class Utilities { - public static GatewayException ContactSupport() => new("Something went wrong with your request. Please contact support."); + public const string BraintreeCustomerIdKey = "btCustomerId"; + + public static BillingException ContactSupport( + string internalMessage = null, + Exception innerException = null) => new("Something went wrong with your request. Please contact support.", + internalMessage, innerException); } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 598a5c062b..2b8ff33211 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -130,6 +130,7 @@ public static class FeatureFlagKeys public const string PM5864DollarThreshold = "PM-5864-dollar-threshold"; public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email"; public const string ShowPaymentMethodWarningBanners = "show-payment-method-warning-banners"; + public const string EnableConsolidatedBilling = "enable-consolidated-billing"; public static List GetAllKeys() { diff --git a/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs b/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs index 761545a255..f8448f4198 100644 --- a/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs +++ b/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs @@ -14,7 +14,7 @@ public class ProviderPlanRepository( globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString), IProviderPlanRepository { - public async Task GetByProviderId(Guid providerId) + public async Task> GetByProviderId(Guid providerId) { var sqlConnection = new SqlConnection(ConnectionString); @@ -23,6 +23,6 @@ public class ProviderPlanRepository( new { ProviderId = providerId }, commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); + return results.ToArray(); } } diff --git a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs index 2f9a707b27..386f7115d7 100644 --- a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs +++ b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs @@ -16,14 +16,17 @@ public class ProviderPlanRepository( mapper, context => context.ProviderPlans), IProviderPlanRepository { - public async Task GetByProviderId(Guid providerId) + public async Task> GetByProviderId(Guid providerId) { using var serviceScope = ServiceScopeFactory.CreateScope(); + var databaseContext = GetDatabaseContext(serviceScope); + var query = from providerPlan in databaseContext.ProviderPlans where providerPlan.ProviderId == providerId select providerPlan; - return await query.FirstOrDefaultAsync(); + + return await query.ToArrayAsync(); } } diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs index fdbcc17e46..9d3c7ebfe5 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -56,7 +56,7 @@ public class OrganizationsControllerTests : IDisposable private readonly IAddSecretsManagerSubscriptionCommand _addSecretsManagerSubscriptionCommand; private readonly IPushNotificationService _pushNotificationService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly IGetSubscriptionQuery _getSubscriptionQuery; + private readonly ISubscriberQueries _subscriberQueries; private readonly IReferenceEventService _referenceEventService; private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand; @@ -86,7 +86,7 @@ public class OrganizationsControllerTests : IDisposable _addSecretsManagerSubscriptionCommand = Substitute.For(); _pushNotificationService = Substitute.For(); _cancelSubscriptionCommand = Substitute.For(); - _getSubscriptionQuery = Substitute.For(); + _subscriberQueries = Substitute.For(); _referenceEventService = Substitute.For(); _organizationEnableCollectionEnhancementsCommand = Substitute.For(); @@ -113,7 +113,7 @@ public class OrganizationsControllerTests : IDisposable _addSecretsManagerSubscriptionCommand, _pushNotificationService, _cancelSubscriptionCommand, - _getSubscriptionQuery, + _subscriberQueries, _referenceEventService, _organizationEnableCollectionEnhancementsCommand); } diff --git a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs index 79aa2ca13d..4af60689c3 100644 --- a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs @@ -57,7 +57,7 @@ public class AccountsControllerTests : IDisposable private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IFeatureService _featureService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly IGetSubscriptionQuery _getSubscriptionQuery; + private readonly ISubscriberQueries _subscriberQueries; private readonly IReferenceEventService _referenceEventService; private readonly ICurrentContext _currentContext; @@ -90,7 +90,7 @@ public class AccountsControllerTests : IDisposable _rotateUserKeyCommand = Substitute.For(); _featureService = Substitute.For(); _cancelSubscriptionCommand = Substitute.For(); - _getSubscriptionQuery = Substitute.For(); + _subscriberQueries = Substitute.For(); _referenceEventService = Substitute.For(); _currentContext = Substitute.For(); _cipherValidator = @@ -122,7 +122,7 @@ public class AccountsControllerTests : IDisposable _rotateUserKeyCommand, _featureService, _cancelSubscriptionCommand, - _getSubscriptionQuery, + _subscriberQueries, _referenceEventService, _currentContext, _cipherValidator, diff --git a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs index 5de14f006f..968bfeb84d 100644 --- a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs +++ b/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs @@ -1,13 +1,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Commands.Implementations; using Bit.Core.Enums; -using Bit.Core.Exceptions; using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using NSubstitute.ReturnsExtensions; using Xunit; +using static Bit.Core.Test.Billing.Utilities; using BT = Braintree; using S = Stripe; @@ -355,13 +355,4 @@ public class RemovePaymentMethodCommandTests return (braintreeGateway, customerGateway, paymentMethodGateway); } - - private static async Task ThrowsContactSupportAsync(Func function) - { - const string message = "Could not remove your payment method. Please contact support for assistance."; - - var exception = await Assert.ThrowsAsync(function); - - Assert.Equal(message, exception.Message); - } } diff --git a/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs b/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs deleted file mode 100644 index adae46a791..0000000000 --- a/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs +++ /dev/null @@ -1,104 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Queries.Implementations; -using Bit.Core.Entities; -using Bit.Core.Exceptions; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Stripe; -using Xunit; - -namespace Bit.Core.Test.Billing.Queries; - -[SutProviderCustomize] -public class GetSubscriptionQueryTests -{ - [Theory, BitAutoData] - public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) - => await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetSubscription(null)); - - [Theory, BitAutoData] - public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ThrowsGatewayException( - Organization organization, - SutProvider sutProvider) - { - organization.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(organization)); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Organization_NoSubscription_ThrowsGatewayException( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(organization)); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Organization_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscription(organization); - - Assert.Equivalent(subscription, gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_User_NoGatewaySubscriptionId_ThrowsGatewayException( - User user, - SutProvider sutProvider) - { - user.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(user)); - } - - [Theory, BitAutoData] - public async Task GetSubscription_User_NoSubscription_ThrowsGatewayException( - User user, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(user)); - } - - [Theory, BitAutoData] - public async Task GetSubscription_User_Succeeds( - User user, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscription(user); - - Assert.Equivalent(subscription, gotSubscription); - } - - private static async Task ThrowsContactSupportAsync(Func function) - { - const string message = "Something went wrong with your request. Please contact support."; - - var exception = await Assert.ThrowsAsync(function); - - Assert.Equal(message, exception.Message); - } -} diff --git a/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs new file mode 100644 index 0000000000..0962ed32b1 --- /dev/null +++ b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs @@ -0,0 +1,151 @@ +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Entities; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Queries.Implementations; +using Bit.Core.Billing.Repositories; +using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Queries; + +[SutProviderCustomize] +public class ProviderBillingQueriesTests +{ + #region GetSubscriptionData + + [Theory, BitAutoData] + public async Task GetSubscriptionData_NullProvider_ReturnsNull( + SutProvider sutProvider, + Guid providerId) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).ReturnsNull(); + + var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); + + Assert.Null(subscriptionData); + + await providerRepository.Received(1).GetByIdAsync(providerId); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionData_NullSubscription_ReturnsNull( + SutProvider sutProvider, + Guid providerId, + Provider provider) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).Returns(provider); + + var subscriberQueries = sutProvider.GetDependency(); + + subscriberQueries.GetSubscription(provider).ReturnsNull(); + + var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); + + Assert.Null(subscriptionData); + + await providerRepository.Received(1).GetByIdAsync(providerId); + + await subscriberQueries.Received(1).GetSubscription( + provider, + Arg.Is( + options => options.Expand.Count == 1 && options.Expand.First() == "customer")); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionData_Success( + SutProvider sutProvider, + Guid providerId, + Provider provider) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).Returns(provider); + + var subscriberQueries = sutProvider.GetDependency(); + + var subscription = new Subscription(); + + subscriberQueries.GetSubscription(provider, Arg.Is( + options => options.Expand.Count == 1 && options.Expand.First() == "customer")).Returns(subscription); + + var providerPlanRepository = sutProvider.GetDependency(); + + var enterprisePlan = new ProviderPlan + { + Id = Guid.NewGuid(), + ProviderId = providerId, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0 + }; + + var teamsPlan = new ProviderPlan + { + Id = Guid.NewGuid(), + ProviderId = providerId, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 50, + PurchasedSeats = 10 + }; + + var providerPlans = new List + { + enterprisePlan, + teamsPlan, + }; + + providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans); + + var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); + + Assert.NotNull(subscriptionData); + + Assert.Equivalent(subscriptionData.Subscription, subscription); + + Assert.Equal(2, subscriptionData.ProviderPlans.Count); + + var configuredEnterprisePlan = + subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => + configuredPlan.PlanType == PlanType.EnterpriseMonthly); + + var configuredTeamsPlan = + subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => + configuredPlan.PlanType == PlanType.TeamsMonthly); + + Compare(enterprisePlan, configuredEnterprisePlan); + + Compare(teamsPlan, configuredTeamsPlan); + + await providerRepository.Received(1).GetByIdAsync(providerId); + + await subscriberQueries.Received(1).GetSubscription( + provider, + Arg.Is( + options => options.Expand.Count == 1 && options.Expand.First() == "customer")); + + await providerPlanRepository.Received(1).GetByProviderId(providerId); + + return; + + void Compare(ProviderPlan providerPlan, ConfiguredProviderPlan configuredProviderPlan) + { + Assert.NotNull(configuredProviderPlan); + Assert.Equal(providerPlan.Id, configuredProviderPlan.Id); + Assert.Equal(providerPlan.ProviderId, configuredProviderPlan.ProviderId); + Assert.Equal(providerPlan.SeatMinimum!.Value, configuredProviderPlan.SeatMinimum); + Assert.Equal(providerPlan.PurchasedSeats!.Value, configuredProviderPlan.PurchasedSeats); + } + } + #endregion +} diff --git a/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs new file mode 100644 index 0000000000..51682a6661 --- /dev/null +++ b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs @@ -0,0 +1,263 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Queries.Implementations; +using Bit.Core.Entities; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +using static Bit.Core.Test.Billing.Utilities; + +namespace Bit.Core.Test.Billing.Queries; + +[SutProviderCustomize] +public class SubscriberQueriesTests +{ + #region GetSubscription + [Theory, BitAutoData] + public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetSubscription(null)); + + [Theory, BitAutoData] + public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = null; + + var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Organization_NoSubscription_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .ReturnsNull(); + + var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Organization_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Equivalent(subscription, gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_User_NoGatewaySubscriptionId_ReturnsNull( + User user, + SutProvider sutProvider) + { + user.GatewaySubscriptionId = null; + + var gotSubscription = await sutProvider.Sut.GetSubscription(user); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_User_NoSubscription_ReturnsNull( + User user, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .ReturnsNull(); + + var gotSubscription = await sutProvider.Sut.GetSubscription(user); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_User_Succeeds( + User user, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscription(user); + + Assert.Equivalent(subscription, gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Provider_NoGatewaySubscriptionId_ReturnsNull( + Provider provider, + SutProvider sutProvider) + { + provider.GatewaySubscriptionId = null; + + var gotSubscription = await sutProvider.Sut.GetSubscription(provider); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Provider_NoSubscription_ReturnsNull( + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) + .ReturnsNull(); + + var gotSubscription = await sutProvider.Sut.GetSubscription(provider); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Provider_Succeeds( + Provider provider, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscription(provider); + + Assert.Equivalent(subscription, gotSubscription); + } + #endregion + + #region GetSubscriptionOrThrow + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetSubscriptionOrThrow(null)); + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Organization_NoGatewaySubscriptionId_ThrowsGatewayException( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Organization_NoSubscription_ThrowsGatewayException( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Organization_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization); + + Assert.Equivalent(subscription, gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_User_NoGatewaySubscriptionId_ThrowsGatewayException( + User user, + SutProvider sutProvider) + { + user.GatewaySubscriptionId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_User_NoSubscription_ThrowsGatewayException( + User user, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_User_Succeeds( + User user, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(user); + + Assert.Equivalent(subscription, gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Provider_NoGatewaySubscriptionId_ThrowsGatewayException( + Provider provider, + SutProvider sutProvider) + { + provider.GatewaySubscriptionId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Provider_NoSubscription_ThrowsGatewayException( + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Provider_Succeeds( + Provider provider, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(provider); + + Assert.Equivalent(subscription, gotSubscription); + } + #endregion +} diff --git a/test/Core.Test/Billing/Utilities.cs b/test/Core.Test/Billing/Utilities.cs index 359c010a29..ea9e6c694c 100644 --- a/test/Core.Test/Billing/Utilities.cs +++ b/test/Core.Test/Billing/Utilities.cs @@ -1,4 +1,4 @@ -using Bit.Core.Exceptions; +using Bit.Core.Billing; using Xunit; using static Bit.Core.Billing.Utilities; @@ -11,7 +11,7 @@ public static class Utilities { var contactSupport = ContactSupport(); - var exception = await Assert.ThrowsAsync(function); + var exception = await Assert.ThrowsAsync(function); Assert.Equal(contactSupport.Message, exception.Message); }