1
0
mirror of https://github.com/bitwarden/server.git synced 2025-06-30 23:52:50 -05:00

[AC-2888] Improve consolidated billing error handling (#4548)

* Fix error handling in provider setup process

This update ensures that when 'enable-consolidated-billing' is on, any exception thrown during the Stripe customer or subscription setup process for the provider will block the remainder of the setup process so the provider does not enter an invalid state

* Refactor the way BillingException is thrown

Made it simpler to just use the exception constructor and also ensured it was added to the exception handling middleware so it could provide a simple response to the client

* Handle all Stripe exceptions in exception handling middleware

* Fixed error response output for billing's provider controllers

* Cleaned up billing owned provider controllers

Changes were made based on feature updates by product and stuff that's no longer needed. No need to expose sensitive endpoints when they're not being used.

* Reafctored get invoices

Removed unnecssarily bloated method from SubscriberService

* Updated error handling for generating the client invoice report

* Moved get provider subscription to controller

This is only used once and the service layer doesn't seem like the correct choice anymore when thinking about error handling with retrieval

* Handled bad request for update tax information

* Split out Stripe configuration from unauthorization

* Run dotnet format

* Addison's feedback
This commit is contained in:
Alex Morask
2024-07-31 09:26:44 -04:00
committed by GitHub
parent 85ddd080cb
commit 398741cec4
33 changed files with 777 additions and 1260 deletions

View File

@ -3,7 +3,9 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Extensions;
using Bit.Core.Context;
using Bit.Core.Models.Api;
using Bit.Core.Services;
using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.AspNetCore.Mvc;
namespace Bit.Api.Billing.Controllers;
@ -11,8 +13,25 @@ namespace Bit.Api.Billing.Controllers;
public abstract class BaseProviderController(
ICurrentContext currentContext,
IFeatureService featureService,
IProviderRepository providerRepository) : Controller
ILogger<BaseProviderController> logger,
IProviderRepository providerRepository,
IUserService userService) : Controller
{
protected readonly IUserService UserService = userService;
protected static NotFound<ErrorResponseModel> NotFoundResponse() =>
TypedResults.NotFound(new ErrorResponseModel("Resource not found."));
protected static JsonHttpResult<ErrorResponseModel> ServerErrorResponse(string errorMessage) =>
TypedResults.Json(
new ErrorResponseModel(errorMessage),
statusCode: StatusCodes.Status500InternalServerError);
protected static JsonHttpResult<ErrorResponseModel> UnauthorizedResponse() =>
TypedResults.Json(
new ErrorResponseModel("Unauthorized."),
statusCode: StatusCodes.Status401Unauthorized);
protected Task<(Provider, IResult)> TryGetBillableProviderForAdminOperation(
Guid providerId) => TryGetBillableProviderAsync(providerId, currentContext.ProviderProviderAdmin);
@ -25,26 +44,53 @@ public abstract class BaseProviderController(
{
if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
{
return (null, TypedResults.NotFound());
logger.LogError(
"Cannot run Consolidated Billing operation for provider ({ProviderID}) while feature flag is disabled",
providerId);
return (null, NotFoundResponse());
}
var provider = await providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
return (null, TypedResults.NotFound());
logger.LogError(
"Cannot find provider ({ProviderID}) for Consolidated Billing operation",
providerId);
return (null, NotFoundResponse());
}
if (!checkAuthorization(providerId))
{
return (null, TypedResults.Unauthorized());
var user = await UserService.GetUserByPrincipalAsync(User);
logger.LogError(
"User ({UserID}) is not authorized to perform Consolidated Billing operation for provider ({ProviderID})",
user?.Id, providerId);
return (null, UnauthorizedResponse());
}
if (!provider.IsBillable())
{
return (null, TypedResults.Unauthorized());
logger.LogError(
"Cannot run Consolidated Billing operation for provider ({ProviderID}) that is not billable",
providerId);
return (null, UnauthorizedResponse());
}
return (provider, null);
if (provider.IsStripeEnabled())
{
return (provider, null);
}
logger.LogError(
"Cannot run Consolidated Billing operation for provider ({ProviderID}) that is missing Stripe configuration",
providerId);
return (null, ServerErrorResponse("Something went wrong with your request. Please contact support."));
}
}

View File

@ -1,15 +1,19 @@
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Models.Api;
using Bit.Core.Models.BitStripe;
using Bit.Core.Services;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Stripe;
using static Bit.Core.Billing.Utilities;
namespace Bit.Api.Billing.Controllers;
[Route("providers/{providerId:guid}/billing")]
@ -17,10 +21,13 @@ namespace Bit.Api.Billing.Controllers;
public class ProviderBillingController(
ICurrentContext currentContext,
IFeatureService featureService,
ILogger<BaseProviderController> logger,
IProviderBillingService providerBillingService,
IProviderPlanRepository providerPlanRepository,
IProviderRepository providerRepository,
ISubscriberService subscriberService,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService) : BaseProviderController(currentContext, featureService, providerRepository)
IUserService userService) : BaseProviderController(currentContext, featureService, logger, providerRepository, userService)
{
[HttpGet("invoices")]
public async Task<IResult> GetInvoicesAsync([FromRoute] Guid providerId)
@ -32,7 +39,10 @@ public class ProviderBillingController(
return result;
}
var invoices = await subscriberService.GetInvoices(provider);
var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions
{
Customer = provider.GatewayCustomerId
});
var response = InvoicesResponse.From(invoices);
@ -53,7 +63,7 @@ public class ProviderBillingController(
if (reportContent == null)
{
return TypedResults.NotFound();
return ServerErrorResponse("We had a problem generating your invoice CSV. Please contact support.");
}
return TypedResults.File(
@ -61,95 +71,6 @@ public class ProviderBillingController(
"text/csv");
}
[HttpGet("payment-information")]
public async Task<IResult> GetPaymentInformationAsync([FromRoute] Guid providerId)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var paymentInformation = await subscriberService.GetPaymentInformation(provider);
if (paymentInformation == null)
{
return TypedResults.NotFound();
}
var response = PaymentInformationResponse.From(paymentInformation);
return TypedResults.Ok(response);
}
[HttpGet("payment-method")]
public async Task<IResult> GetPaymentMethodAsync([FromRoute] Guid providerId)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var maskedPaymentMethod = await subscriberService.GetPaymentMethod(provider);
if (maskedPaymentMethod == null)
{
return TypedResults.NotFound();
}
var response = MaskedPaymentMethodResponse.From(maskedPaymentMethod);
return TypedResults.Ok(response);
}
[HttpPut("payment-method")]
public async Task<IResult> UpdatePaymentMethodAsync(
[FromRoute] Guid providerId,
[FromBody] TokenizedPaymentMethodRequestBody requestBody)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var tokenizedPaymentMethod = new TokenizedPaymentMethodDTO(
requestBody.Type,
requestBody.Token);
await subscriberService.UpdatePaymentMethod(provider, tokenizedPaymentMethod);
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
new SubscriptionUpdateOptions
{
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically
});
return TypedResults.Ok();
}
[HttpPost]
[Route("payment-method/verify-bank-account")]
public async Task<IResult> VerifyBankAccountAsync(
[FromRoute] Guid providerId,
[FromBody] VerifyBankAccountRequestBody requestBody)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
await subscriberService.VerifyBankAccount(provider, (requestBody.Amount1, requestBody.Amount2));
return TypedResults.Ok();
}
[HttpGet("subscription")]
public async Task<IResult> GetSubscriptionAsync([FromRoute] Guid providerId)
{
@ -160,36 +81,20 @@ public class ProviderBillingController(
return result;
}
var consolidatedBillingSubscription = await providerBillingService.GetConsolidatedBillingSubscription(provider);
var subscription = await stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId,
new SubscriptionGetOptions { Expand = ["customer.tax_ids", "test_clock"] });
if (consolidatedBillingSubscription == null)
{
return TypedResults.NotFound();
}
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
var response = ConsolidatedBillingSubscriptionResponse.From(consolidatedBillingSubscription);
var taxInformation = GetTaxInformation(subscription.Customer);
return TypedResults.Ok(response);
}
var subscriptionSuspension = await GetSubscriptionSuspensionAsync(stripeAdapter, subscription);
[HttpGet("tax-information")]
public async Task<IResult> GetTaxInformationAsync([FromRoute] Guid providerId)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var taxInformation = await subscriberService.GetTaxInformation(provider);
if (taxInformation == null)
{
return TypedResults.NotFound();
}
var response = TaxInformationResponse.From(taxInformation);
var response = ProviderSubscriptionResponse.From(
subscription,
providerPlans,
taxInformation,
subscriptionSuspension);
return TypedResults.Ok(response);
}
@ -206,7 +111,13 @@ public class ProviderBillingController(
return result;
}
var taxInformation = new TaxInformationDTO(
if (requestBody is not { Country: not null, PostalCode: not null })
{
return TypedResults.BadRequest(
new ErrorResponseModel("Country and postal code are required to update your tax information."));
}
var taxInformation = new TaxInformation(
requestBody.Country,
requestBody.PostalCode,
requestBody.TaxId,

View File

@ -15,13 +15,13 @@ namespace Bit.Api.Billing.Controllers;
public class ProviderClientsController(
ICurrentContext currentContext,
IFeatureService featureService,
ILogger<ProviderClientsController> logger,
ILogger<BaseProviderController> logger,
IOrganizationRepository organizationRepository,
IProviderBillingService providerBillingService,
IProviderOrganizationRepository providerOrganizationRepository,
IProviderRepository providerRepository,
IProviderService providerService,
IUserService userService) : BaseProviderController(currentContext, featureService, providerRepository)
IUserService userService) : BaseProviderController(currentContext, featureService, logger, providerRepository, userService)
{
[HttpPost]
public async Task<IResult> CreateAsync(
@ -35,11 +35,11 @@ public class ProviderClientsController(
return result;
}
var user = await userService.GetUserByPrincipalAsync(User);
var user = await UserService.GetUserByPrincipalAsync(User);
if (user == null)
{
return TypedResults.Unauthorized();
return UnauthorizedResponse();
}
var organizationSignup = new OrganizationSignup
@ -63,13 +63,6 @@ public class ProviderClientsController(
var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId);
if (clientOrganization == null)
{
logger.LogError("Newly created client organization ({ID}) could not be found", providerOrganization.OrganizationId);
return TypedResults.Problem();
}
await providerBillingService.ScaleSeats(
provider,
requestBody.PlanType,
@ -103,18 +96,11 @@ public class ProviderClientsController(
if (providerOrganization == null)
{
return TypedResults.NotFound();
return NotFoundResponse();
}
var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId);
if (clientOrganization == null)
{
logger.LogError("The client organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id);
return TypedResults.Problem();
}
if (clientOrganization.Seats != requestBody.AssignedSeats)
{
await providerBillingService.AssignSeatsToClientOrganization(

View File

@ -3,16 +3,16 @@
namespace Bit.Api.Billing.Models.Responses;
public record InvoicesResponse(
List<InvoiceDTO> Invoices)
List<InvoiceResponse> Invoices)
{
public static InvoicesResponse From(IEnumerable<Invoice> invoices) => new(
invoices
.Where(i => i.Status is "open" or "paid" or "uncollectible")
.OrderByDescending(i => i.Created)
.Select(InvoiceDTO.From).ToList());
.Select(InvoiceResponse.From).ToList());
}
public record InvoiceDTO(
public record InvoiceResponse(
string Id,
DateTime Date,
string Number,
@ -21,7 +21,7 @@ public record InvoiceDTO(
DateTime? DueDate,
string Url)
{
public static InvoiceDTO From(Invoice invoice) => new(
public static InvoiceResponse From(Invoice invoice) => new(
invoice.Id,
invoice.Created,
invoice.Number,

View File

@ -5,7 +5,7 @@ namespace Bit.Api.Billing.Models.Responses;
public record PaymentInformationResponse(
long AccountCredit,
MaskedPaymentMethodDTO PaymentMethod,
TaxInformationDTO TaxInformation)
TaxInformation TaxInformation)
{
public static PaymentInformationResponse From(PaymentInformationDTO paymentInformation) =>
new(

View File

@ -1,43 +1,48 @@
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Models;
using Bit.Core.Utilities;
using Stripe;
namespace Bit.Api.Billing.Models.Responses;
public record ConsolidatedBillingSubscriptionResponse(
public record ProviderSubscriptionResponse(
string Status,
DateTime CurrentPeriodEndDate,
decimal? DiscountPercentage,
string CollectionMethod,
IEnumerable<ProviderPlanResponse> Plans,
long AccountCredit,
TaxInformationDTO TaxInformation,
TaxInformation TaxInformation,
DateTime? CancelAt,
SubscriptionSuspensionDTO Suspension)
SubscriptionSuspension Suspension)
{
private const string _annualCadence = "Annual";
private const string _monthlyCadence = "Monthly";
public static ConsolidatedBillingSubscriptionResponse From(
ConsolidatedBillingSubscriptionDTO consolidatedBillingSubscription)
public static ProviderSubscriptionResponse From(
Subscription subscription,
ICollection<ProviderPlan> providerPlans,
TaxInformation taxInformation,
SubscriptionSuspension subscriptionSuspension)
{
var (providerPlans, subscription, taxInformation, suspension) = consolidatedBillingSubscription;
var providerPlanResponses = providerPlans
.Select(providerPlan =>
.Where(providerPlan => providerPlan.IsConfigured())
.Select(ConfiguredProviderPlan.From)
.Select(configuredProviderPlan =>
{
var plan = StaticStore.GetPlan(providerPlan.PlanType);
var cost = (providerPlan.SeatMinimum + providerPlan.PurchasedSeats) * plan.PasswordManager.ProviderPortalSeatPrice;
var plan = StaticStore.GetPlan(configuredProviderPlan.PlanType);
var cost = (configuredProviderPlan.SeatMinimum + configuredProviderPlan.PurchasedSeats) * plan.PasswordManager.ProviderPortalSeatPrice;
var cadence = plan.IsAnnual ? _annualCadence : _monthlyCadence;
return new ProviderPlanResponse(
plan.Name,
providerPlan.SeatMinimum,
providerPlan.PurchasedSeats,
providerPlan.AssignedSeats,
configuredProviderPlan.SeatMinimum,
configuredProviderPlan.PurchasedSeats,
configuredProviderPlan.AssignedSeats,
cost,
cadence);
});
return new ConsolidatedBillingSubscriptionResponse(
return new ProviderSubscriptionResponse(
subscription.Status,
subscription.CurrentPeriodEnd,
subscription.Customer?.Discount?.Coupon?.PercentOff,
@ -46,7 +51,7 @@ public record ConsolidatedBillingSubscriptionResponse(
subscription.Customer?.Balance ?? 0,
taxInformation,
subscription.CancelAt,
suspension);
subscriptionSuspension);
}
}

View File

@ -11,7 +11,7 @@ public record TaxInformationResponse(
string City,
string State)
{
public static TaxInformationResponse From(TaxInformationDTO taxInformation)
public static TaxInformationResponse From(TaxInformation taxInformation)
=> new(
taxInformation.Country,
taxInformation.PostalCode,