From 7f65a655d4dbfe578b8b5f982de34a6dc757562e Mon Sep 17 00:00:00 2001
From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com>
Date: Thu, 10 Jul 2025 08:32:25 -0500
Subject: [PATCH] [PM-21881] Manage payment details outside of checkout (#6032)
* Add feature flag
* Further establish billing command pattern and use in PreviewTaxAmountCommand
* Add billing address models/commands/queries/tests
* Update TypeReadingJsonConverter to account for new union types
* Add payment method models/commands/queries/tests
* Add credit models/commands/queries/tests
* Add command/query registrations
* Add new endpoints to support new command model and payment functionality
* Run dotnet format
* Add InjectUserAttribute for easier AccountBillilngVNextController handling
* Add InjectOrganizationAttribute for easier OrganizationBillingVNextController handling
* Add InjectProviderAttribute for easier ProviderBillingVNextController handling
* Add XML documentation for billing command pipeline
* Fix StripeConstants post-nullability
* More nullability cleanup
* Run dotnet format
---
.../Attributes/InjectOrganizationAttribute.cs | 61 +++
.../Attributes/InjectProviderAttribute.cs | 80 ++++
.../Billing/Attributes/InjectUserAttribute.cs | 53 +++
.../Controllers/BaseBillingController.cs | 44 +-
src/Api/Billing/Controllers/TaxController.cs | 5 +-
.../VNext/AccountBillingVNextController.cs | 64 +++
.../OrganizationBillingVNextController.cs | 107 +++++
.../VNext/ProviderBillingVNextController.cs | 97 +++++
.../Requests/Payment/BillingAddressRequest.cs | 20 +
.../Requests/Payment/BitPayCreditRequest.cs | 13 +
.../Payment/CheckoutBillingAddressRequest.cs | 24 ++
.../Payment/MinimalBillingAddressRequest.cs | 16 +
.../Payment/TokenizedPaymentMethodRequest.cs | 39 ++
.../Payment/VerifyBankAccountRequest.cs | 9 +
.../ManageOrganizationBillingRequirement.cs | 18 +
src/Api/Utilities/StringMatchesAttribute.cs | 18 +
src/Core/Billing/Commands/BillingCommand.cs | 62 +++
.../Billing/Commands/BillingCommandResult.cs | 31 ++
src/Core/Billing/Constants/StripeConstants.cs | 12 +-
.../Extensions/ServiceCollectionExtensions.cs | 2 +
.../Extensions/SubscriberExtensions.cs | 16 +-
.../Billing/Models/BillingCommandResult.cs | 36 --
.../Billing/Payment/Clients/BitPayClient.cs | 24 ++
.../CreateBitPayInvoiceForCreditCommand.cs | 59 +++
.../Commands/UpdateBillingAddressCommand.cs | 129 ++++++
.../Commands/UpdatePaymentMethodCommand.cs | 205 +++++++++
.../Commands/VerifyBankAccountCommand.cs | 63 +++
.../Billing/Payment/Models/BillingAddress.cs | 30 ++
.../Payment/Models/MaskedPaymentMethod.cs | 120 ++++++
.../Payment/Models/ProductUsageType.cs | 7 +
.../Models/TokenizablePaymentMethodType.cs | 8 +
.../Payment/Models/TokenizedPaymentMethod.cs | 8 +
.../Payment/Queries/GetBillingAddressQuery.cs | 41 ++
.../Billing/Payment/Queries/GetCreditQuery.cs | 26 ++
.../Payment/Queries/GetPaymentMethodQuery.cs | 96 +++++
src/Core/Billing/Payment/Registrations.cs | 24 ++
.../Pricing/JSON/TypeReadingJsonConverter.cs | 12 +-
.../Tax/Commands/PreviewTaxAmountCommand.cs | 172 ++++----
src/Core/Constants.cs | 1 +
.../InjectOrganizationAttributeTests.cs | 132 ++++++
.../InjectProviderAttributeTests.cs | 190 +++++++++
.../Attributes/InjectUserAttributesTests.cs | 129 ++++++
.../Billing/Extensions/StripeExtensions.cs | 18 +
...reateBitPayInvoiceForCreditCommandTests.cs | 94 +++++
.../UpdateBillingAddressCommandTests.cs | 349 +++++++++++++++
.../UpdatePaymentMethodCommandTests.cs | 399 ++++++++++++++++++
.../Commands/VerifyBankAccountCommandTests.cs | 81 ++++
.../Models/MaskedPaymentMethodTests.cs | 63 +++
.../Queries/GetBillingAddressQueryTests.cs | 204 +++++++++
.../Payment/Queries/GetCreditQueryTests.cs | 41 ++
.../Queries/GetPaymentMethodQueryTests.cs | 327 ++++++++++++++
.../Commands/PreviewTaxAmountCommandTests.cs | 72 +---
52 files changed, 3736 insertions(+), 215 deletions(-)
create mode 100644 src/Api/Billing/Attributes/InjectOrganizationAttribute.cs
create mode 100644 src/Api/Billing/Attributes/InjectProviderAttribute.cs
create mode 100644 src/Api/Billing/Attributes/InjectUserAttribute.cs
create mode 100644 src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs
create mode 100644 src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs
create mode 100644 src/Api/Billing/Controllers/VNext/ProviderBillingVNextController.cs
create mode 100644 src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs
create mode 100644 src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs
create mode 100644 src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs
create mode 100644 src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs
create mode 100644 src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs
create mode 100644 src/Api/Billing/Models/Requests/Payment/VerifyBankAccountRequest.cs
create mode 100644 src/Api/Billing/Models/Requirements/ManageOrganizationBillingRequirement.cs
create mode 100644 src/Api/Utilities/StringMatchesAttribute.cs
create mode 100644 src/Core/Billing/Commands/BillingCommand.cs
create mode 100644 src/Core/Billing/Commands/BillingCommandResult.cs
delete mode 100644 src/Core/Billing/Models/BillingCommandResult.cs
create mode 100644 src/Core/Billing/Payment/Clients/BitPayClient.cs
create mode 100644 src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs
create mode 100644 src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs
create mode 100644 src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs
create mode 100644 src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs
create mode 100644 src/Core/Billing/Payment/Models/BillingAddress.cs
create mode 100644 src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs
create mode 100644 src/Core/Billing/Payment/Models/ProductUsageType.cs
create mode 100644 src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs
create mode 100644 src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs
create mode 100644 src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs
create mode 100644 src/Core/Billing/Payment/Queries/GetCreditQuery.cs
create mode 100644 src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs
create mode 100644 src/Core/Billing/Payment/Registrations.cs
create mode 100644 test/Api.Test/Billing/Attributes/InjectOrganizationAttributeTests.cs
create mode 100644 test/Api.Test/Billing/Attributes/InjectProviderAttributeTests.cs
create mode 100644 test/Api.Test/Billing/Attributes/InjectUserAttributesTests.cs
create mode 100644 test/Core.Test/Billing/Extensions/StripeExtensions.cs
create mode 100644 test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs
create mode 100644 test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs
create mode 100644 test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs
create mode 100644 test/Core.Test/Billing/Payment/Commands/VerifyBankAccountCommandTests.cs
create mode 100644 test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs
create mode 100644 test/Core.Test/Billing/Payment/Queries/GetBillingAddressQueryTests.cs
create mode 100644 test/Core.Test/Billing/Payment/Queries/GetCreditQueryTests.cs
create mode 100644 test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs
diff --git a/src/Api/Billing/Attributes/InjectOrganizationAttribute.cs b/src/Api/Billing/Attributes/InjectOrganizationAttribute.cs
new file mode 100644
index 0000000000..f4c2a8c637
--- /dev/null
+++ b/src/Api/Billing/Attributes/InjectOrganizationAttribute.cs
@@ -0,0 +1,61 @@
+#nullable enable
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.Models.Api;
+using Bit.Core.Repositories;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.Filters;
+
+namespace Bit.Api.Billing.Attributes;
+
+///
+/// An action filter that facilitates the injection of a parameter into the executing action method arguments.
+///
+///
+/// This attribute retrieves the organization associated with the 'organizationId' included in the executing context's route data. If the organization cannot be found,
+/// the request is terminated with a not found response.
+/// The injected
+/// parameter must be marked with a [BindNever] attribute to short-circuit the model-binding system.
+///
+///
+/// EndpointAsync([BindNever] Organization organization)
+/// ]]>
+///
+///
+public class InjectOrganizationAttribute : ActionFilterAttribute
+{
+ public override async Task OnActionExecutionAsync(
+ ActionExecutingContext context,
+ ActionExecutionDelegate next)
+ {
+ if (!context.RouteData.Values.TryGetValue("organizationId", out var routeValue) ||
+ !Guid.TryParse(routeValue?.ToString(), out var organizationId))
+ {
+ context.Result = new BadRequestObjectResult(new ErrorResponseModel("Route parameter 'organizationId' is missing or invalid."));
+ return;
+ }
+
+ var organizationRepository = context.HttpContext.RequestServices
+ .GetRequiredService();
+
+ var organization = await organizationRepository.GetByIdAsync(organizationId);
+
+ if (organization == null)
+ {
+ context.Result = new NotFoundObjectResult(new ErrorResponseModel("Organization not found."));
+ return;
+ }
+
+ var organizationParameter = context.ActionDescriptor.Parameters
+ .FirstOrDefault(p => p.ParameterType == typeof(Organization));
+
+ if (organizationParameter != null)
+ {
+ context.ActionArguments[organizationParameter.Name] = organization;
+ }
+
+ await next();
+ }
+}
diff --git a/src/Api/Billing/Attributes/InjectProviderAttribute.cs b/src/Api/Billing/Attributes/InjectProviderAttribute.cs
new file mode 100644
index 0000000000..e65dda37c3
--- /dev/null
+++ b/src/Api/Billing/Attributes/InjectProviderAttribute.cs
@@ -0,0 +1,80 @@
+#nullable enable
+using Bit.Api.Models.Public.Response;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.AdminConsole.Enums.Provider;
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.Context;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.Filters;
+
+namespace Bit.Api.Billing.Attributes;
+
+///
+/// An action filter that facilitates the injection of a parameter into the executing action method arguments after performing an authorization check.
+///
+///
+/// This attribute retrieves the provider associated with the 'providerId' included in the executing context's route data. If the provider cannot be found,
+/// the request is terminated with a not-found response. It then checks the authorization level for the provider using the provided .
+/// If this check fails, the request is terminated with an unauthorized response.
+/// The injected
+/// parameter must be marked with a [BindNever] attribute to short-circuit the model-binding system.
+///
+///
+/// EndpointAsync([BindNever] Provider provider)
+/// ]]>
+///
+/// The desired access level for the authorization check.
+///
+public class InjectProviderAttribute(ProviderUserType providerUserType) : ActionFilterAttribute
+{
+ public override async Task OnActionExecutionAsync(
+ ActionExecutingContext context,
+ ActionExecutionDelegate next)
+ {
+ if (!context.RouteData.Values.TryGetValue("providerId", out var routeValue) ||
+ !Guid.TryParse(routeValue?.ToString(), out var providerId))
+ {
+ context.Result = new BadRequestObjectResult(new ErrorResponseModel("Route parameter 'providerId' is missing or invalid."));
+ return;
+ }
+
+ var providerRepository = context.HttpContext.RequestServices
+ .GetRequiredService();
+
+ var provider = await providerRepository.GetByIdAsync(providerId);
+
+ if (provider == null)
+ {
+ context.Result = new NotFoundObjectResult(new ErrorResponseModel("Provider not found."));
+ return;
+ }
+
+ var currentContext = context.HttpContext.RequestServices.GetRequiredService();
+
+ var unauthorized = providerUserType switch
+ {
+ ProviderUserType.ProviderAdmin => !currentContext.ProviderProviderAdmin(providerId),
+ ProviderUserType.ServiceUser => !currentContext.ProviderUser(providerId),
+ _ => false
+ };
+
+ if (unauthorized)
+ {
+ context.Result = new UnauthorizedObjectResult(new ErrorResponseModel("Unauthorized."));
+ return;
+ }
+
+ var providerParameter = context.ActionDescriptor.Parameters
+ .FirstOrDefault(p => p.ParameterType == typeof(Provider));
+
+ if (providerParameter != null)
+ {
+ context.ActionArguments[providerParameter.Name] = provider;
+ }
+
+ await next();
+ }
+}
diff --git a/src/Api/Billing/Attributes/InjectUserAttribute.cs b/src/Api/Billing/Attributes/InjectUserAttribute.cs
new file mode 100644
index 0000000000..0b614bdc44
--- /dev/null
+++ b/src/Api/Billing/Attributes/InjectUserAttribute.cs
@@ -0,0 +1,53 @@
+#nullable enable
+using Bit.Core.Entities;
+using Bit.Core.Models.Api;
+using Bit.Core.Services;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.Filters;
+
+namespace Bit.Api.Billing.Attributes;
+
+///
+/// An action filter that facilitates the injection of a parameter into the executing action method arguments.
+///
+///
+/// This attribute retrieves the authorized user associated with the current HTTP context using the service.
+/// If the user is unauthorized or cannot be found, the request is terminated with an unauthorized response.
+/// The injected
+/// parameter must be marked with a [BindNever] attribute to short-circuit the model-binding system.
+///
+///
+/// EndpointAsync([BindNever] User user)
+/// ]]>
+///
+///
+public class InjectUserAttribute : ActionFilterAttribute
+{
+ public override async Task OnActionExecutionAsync(
+ ActionExecutingContext context,
+ ActionExecutionDelegate next)
+ {
+ var userService = context.HttpContext.RequestServices.GetRequiredService();
+
+ var user = await userService.GetUserByPrincipalAsync(context.HttpContext.User);
+
+ if (user == null)
+ {
+ context.Result = new UnauthorizedObjectResult(new ErrorResponseModel("Unauthorized."));
+ return;
+ }
+
+ var userParameter =
+ context.ActionDescriptor.Parameters.FirstOrDefault(parameter => parameter.ParameterType == typeof(User));
+
+ if (userParameter != null)
+ {
+ context.ActionArguments[userParameter.Name] = user;
+ }
+
+ await next();
+ }
+}
diff --git a/src/Api/Billing/Controllers/BaseBillingController.cs b/src/Api/Billing/Controllers/BaseBillingController.cs
index 5f7005fdfc..057c8309fb 100644
--- a/src/Api/Billing/Controllers/BaseBillingController.cs
+++ b/src/Api/Billing/Controllers/BaseBillingController.cs
@@ -1,4 +1,6 @@
-using Bit.Core.Models.Api;
+#nullable enable
+using Bit.Core.Billing.Commands;
+using Bit.Core.Models.Api;
using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.AspNetCore.Mvc;
@@ -6,20 +8,50 @@ namespace Bit.Api.Billing.Controllers;
public abstract class BaseBillingController : Controller
{
+ ///
+ /// Processes the result of a billing command and converts it to an appropriate HTTP result response.
+ ///
+ ///
+ /// Result to response mappings:
+ ///
+ /// - : 200 OK
+ /// - : 400 BAD_REQUEST
+ /// - : 409 CONFLICT
+ /// - : 500 INTERNAL_SERVER_ERROR
+ ///
+ ///
+ /// The type of the successful result.
+ /// The result of executing the billing command.
+ /// An HTTP result response representing the outcome of the command execution.
+ protected static IResult Handle(BillingCommandResult result) =>
+ result.Match(
+ TypedResults.Ok,
+ badRequest => Error.BadRequest(badRequest.Response),
+ conflict => Error.Conflict(conflict.Response),
+ unhandled => Error.ServerError(unhandled.Response, unhandled.Exception));
+
protected static class Error
{
- public static BadRequest BadRequest(Dictionary> errors) =>
- TypedResults.BadRequest(new ErrorResponseModel(errors));
-
public static BadRequest BadRequest(string message) =>
TypedResults.BadRequest(new ErrorResponseModel(message));
+ public static JsonHttpResult Conflict(string message) =>
+ TypedResults.Json(
+ new ErrorResponseModel(message),
+ statusCode: StatusCodes.Status409Conflict);
+
public static NotFound NotFound() =>
TypedResults.NotFound(new ErrorResponseModel("Resource not found."));
- public static JsonHttpResult ServerError(string message = "Something went wrong with your request. Please contact support.") =>
+ public static JsonHttpResult ServerError(
+ string message = "Something went wrong with your request. Please contact support for assistance.",
+ Exception? exception = null) =>
TypedResults.Json(
- new ErrorResponseModel(message),
+ exception == null ? new ErrorResponseModel(message) : new ErrorResponseModel(message)
+ {
+ ExceptionMessage = exception.Message,
+ ExceptionStackTrace = exception.StackTrace
+ },
statusCode: StatusCodes.Status500InternalServerError);
public static JsonHttpResult Unauthorized(string message = "Unauthorized.") =>
diff --git a/src/Api/Billing/Controllers/TaxController.cs b/src/Api/Billing/Controllers/TaxController.cs
index 7b8b9d960f..d2c1c36726 100644
--- a/src/Api/Billing/Controllers/TaxController.cs
+++ b/src/Api/Billing/Controllers/TaxController.cs
@@ -28,9 +28,6 @@ public class TaxController(
var result = await previewTaxAmountCommand.Run(parameters);
- return result.Match(
- taxAmount => TypedResults.Ok(new { TaxAmount = taxAmount }),
- badRequest => Error.BadRequest(badRequest.TranslationKey),
- unhandled => Error.ServerError(unhandled.TranslationKey));
+ return Handle(result);
}
}
diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs
new file mode 100644
index 0000000000..e3b702e36d
--- /dev/null
+++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs
@@ -0,0 +1,64 @@
+#nullable enable
+using Bit.Api.Billing.Attributes;
+using Bit.Api.Billing.Models.Requests.Payment;
+using Bit.Core.Billing.Payment.Commands;
+using Bit.Core.Billing.Payment.Queries;
+using Bit.Core.Entities;
+using Bit.Core.Utilities;
+using Microsoft.AspNetCore.Authorization;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.ModelBinding;
+
+namespace Bit.Api.Billing.Controllers.VNext;
+
+[Authorize("Application")]
+[Route("account/billing/vnext")]
+[SelfHosted(NotSelfHostedOnly = true)]
+public class AccountBillingVNextController(
+ ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand,
+ IGetCreditQuery getCreditQuery,
+ IGetPaymentMethodQuery getPaymentMethodQuery,
+ IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingController
+{
+ [HttpGet("credit")]
+ [InjectUser]
+ public async Task GetCreditAsync(
+ [BindNever] User user)
+ {
+ var credit = await getCreditQuery.Run(user);
+ return TypedResults.Ok(credit);
+ }
+
+ [HttpPost("credit/bitpay")]
+ [InjectUser]
+ public async Task AddCreditViaBitPayAsync(
+ [BindNever] User user,
+ [FromBody] BitPayCreditRequest request)
+ {
+ var result = await createBitPayInvoiceForCreditCommand.Run(
+ user,
+ request.Amount,
+ request.RedirectUrl);
+ return Handle(result);
+ }
+
+ [HttpGet("payment-method")]
+ [InjectUser]
+ public async Task GetPaymentMethodAsync(
+ [BindNever] User user)
+ {
+ var paymentMethod = await getPaymentMethodQuery.Run(user);
+ return TypedResults.Ok(paymentMethod);
+ }
+
+ [HttpPut("payment-method")]
+ [InjectUser]
+ public async Task UpdatePaymentMethodAsync(
+ [BindNever] User user,
+ [FromBody] TokenizedPaymentMethodRequest request)
+ {
+ var (paymentMethod, billingAddress) = request.ToDomain();
+ var result = await updatePaymentMethodCommand.Run(user, paymentMethod, billingAddress);
+ return Handle(result);
+ }
+}
diff --git a/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs
new file mode 100644
index 0000000000..429f2065f6
--- /dev/null
+++ b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs
@@ -0,0 +1,107 @@
+#nullable enable
+using Bit.Api.AdminConsole.Authorization;
+using Bit.Api.Billing.Attributes;
+using Bit.Api.Billing.Models.Requests.Payment;
+using Bit.Api.Billing.Models.Requirements;
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.Billing.Payment.Commands;
+using Bit.Core.Billing.Payment.Queries;
+using Bit.Core.Utilities;
+using Microsoft.AspNetCore.Authorization;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.ModelBinding;
+// ReSharper disable RouteTemplates.MethodMissingRouteParameters
+
+namespace Bit.Api.Billing.Controllers.VNext;
+
+[Authorize("Application")]
+[Route("organizations/{organizationId:guid}/billing/vnext")]
+[SelfHosted(NotSelfHostedOnly = true)]
+public class OrganizationBillingVNextController(
+ ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand,
+ IGetBillingAddressQuery getBillingAddressQuery,
+ IGetCreditQuery getCreditQuery,
+ IGetPaymentMethodQuery getPaymentMethodQuery,
+ IUpdateBillingAddressCommand updateBillingAddressCommand,
+ IUpdatePaymentMethodCommand updatePaymentMethodCommand,
+ IVerifyBankAccountCommand verifyBankAccountCommand) : BaseBillingController
+{
+ [Authorize]
+ [HttpGet("address")]
+ [InjectOrganization]
+ public async Task GetBillingAddressAsync(
+ [BindNever] Organization organization)
+ {
+ var billingAddress = await getBillingAddressQuery.Run(organization);
+ return TypedResults.Ok(billingAddress);
+ }
+
+ [Authorize]
+ [HttpPut("address")]
+ [InjectOrganization]
+ public async Task UpdateBillingAddressAsync(
+ [BindNever] Organization organization,
+ [FromBody] BillingAddressRequest request)
+ {
+ var billingAddress = request.ToDomain();
+ var result = await updateBillingAddressCommand.Run(organization, billingAddress);
+ return Handle(result);
+ }
+
+ [Authorize]
+ [HttpGet("credit")]
+ [InjectOrganization]
+ public async Task GetCreditAsync(
+ [BindNever] Organization organization)
+ {
+ var credit = await getCreditQuery.Run(organization);
+ return TypedResults.Ok(credit);
+ }
+
+ [Authorize]
+ [HttpPost("credit/bitpay")]
+ [InjectOrganization]
+ public async Task AddCreditViaBitPayAsync(
+ [BindNever] Organization organization,
+ [FromBody] BitPayCreditRequest request)
+ {
+ var result = await createBitPayInvoiceForCreditCommand.Run(
+ organization,
+ request.Amount,
+ request.RedirectUrl);
+ return Handle(result);
+ }
+
+ [Authorize]
+ [HttpGet("payment-method")]
+ [InjectOrganization]
+ public async Task GetPaymentMethodAsync(
+ [BindNever] Organization organization)
+ {
+ var paymentMethod = await getPaymentMethodQuery.Run(organization);
+ return TypedResults.Ok(paymentMethod);
+ }
+
+ [Authorize]
+ [HttpPut("payment-method")]
+ [InjectOrganization]
+ public async Task UpdatePaymentMethodAsync(
+ [BindNever] Organization organization,
+ [FromBody] TokenizedPaymentMethodRequest request)
+ {
+ var (paymentMethod, billingAddress) = request.ToDomain();
+ var result = await updatePaymentMethodCommand.Run(organization, paymentMethod, billingAddress);
+ return Handle(result);
+ }
+
+ [Authorize]
+ [HttpPost("payment-method/verify-bank-account")]
+ [InjectOrganization]
+ public async Task VerifyBankAccountAsync(
+ [BindNever] Organization organization,
+ [FromBody] VerifyBankAccountRequest request)
+ {
+ var result = await verifyBankAccountCommand.Run(organization, request.DescriptorCode);
+ return Handle(result);
+ }
+}
diff --git a/src/Api/Billing/Controllers/VNext/ProviderBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/ProviderBillingVNextController.cs
new file mode 100644
index 0000000000..be7963236f
--- /dev/null
+++ b/src/Api/Billing/Controllers/VNext/ProviderBillingVNextController.cs
@@ -0,0 +1,97 @@
+#nullable enable
+using Bit.Api.Billing.Attributes;
+using Bit.Api.Billing.Models.Requests.Payment;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.AdminConsole.Enums.Provider;
+using Bit.Core.Billing.Payment.Commands;
+using Bit.Core.Billing.Payment.Queries;
+using Bit.Core.Utilities;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.ModelBinding;
+// ReSharper disable RouteTemplates.MethodMissingRouteParameters
+
+namespace Bit.Api.Billing.Controllers.VNext;
+
+[Route("providers/{providerId:guid}/billing/vnext")]
+[SelfHosted(NotSelfHostedOnly = true)]
+public class ProviderBillingVNextController(
+ ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand,
+ IGetBillingAddressQuery getBillingAddressQuery,
+ IGetCreditQuery getCreditQuery,
+ IGetPaymentMethodQuery getPaymentMethodQuery,
+ IUpdateBillingAddressCommand updateBillingAddressCommand,
+ IUpdatePaymentMethodCommand updatePaymentMethodCommand,
+ IVerifyBankAccountCommand verifyBankAccountCommand) : BaseBillingController
+{
+ [HttpGet("address")]
+ [InjectProvider(ProviderUserType.ProviderAdmin)]
+ public async Task GetBillingAddressAsync(
+ [BindNever] Provider provider)
+ {
+ var billingAddress = await getBillingAddressQuery.Run(provider);
+ return TypedResults.Ok(billingAddress);
+ }
+
+ [HttpPut("address")]
+ [InjectProvider(ProviderUserType.ProviderAdmin)]
+ public async Task UpdateBillingAddressAsync(
+ [BindNever] Provider provider,
+ [FromBody] BillingAddressRequest request)
+ {
+ var billingAddress = request.ToDomain();
+ var result = await updateBillingAddressCommand.Run(provider, billingAddress);
+ return Handle(result);
+ }
+
+ [HttpGet("credit")]
+ [InjectProvider(ProviderUserType.ProviderAdmin)]
+ public async Task GetCreditAsync(
+ [BindNever] Provider provider)
+ {
+ var credit = await getCreditQuery.Run(provider);
+ return TypedResults.Ok(credit);
+ }
+
+ [HttpPost("credit/bitpay")]
+ [InjectProvider(ProviderUserType.ProviderAdmin)]
+ public async Task AddCreditViaBitPayAsync(
+ [BindNever] Provider provider,
+ [FromBody] BitPayCreditRequest request)
+ {
+ var result = await createBitPayInvoiceForCreditCommand.Run(
+ provider,
+ request.Amount,
+ request.RedirectUrl);
+ return Handle(result);
+ }
+
+ [HttpGet("payment-method")]
+ [InjectProvider(ProviderUserType.ProviderAdmin)]
+ public async Task GetPaymentMethodAsync(
+ [BindNever] Provider provider)
+ {
+ var paymentMethod = await getPaymentMethodQuery.Run(provider);
+ return TypedResults.Ok(paymentMethod);
+ }
+
+ [HttpPut("payment-method")]
+ [InjectProvider(ProviderUserType.ProviderAdmin)]
+ public async Task UpdatePaymentMethodAsync(
+ [BindNever] Provider provider,
+ [FromBody] TokenizedPaymentMethodRequest request)
+ {
+ var (paymentMethod, billingAddress) = request.ToDomain();
+ var result = await updatePaymentMethodCommand.Run(provider, paymentMethod, billingAddress);
+ return Handle(result);
+ }
+
+ [HttpPost("payment-method/verify-bank-account")]
+ [InjectProvider(ProviderUserType.ProviderAdmin)]
+ public async Task VerifyBankAccountAsync(
+ [BindNever] Provider provider,
+ [FromBody] VerifyBankAccountRequest request)
+ {
+ var result = await verifyBankAccountCommand.Run(provider, request.DescriptorCode);
+ return Handle(result);
+ }
+}
diff --git a/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs
new file mode 100644
index 0000000000..5c3c47f585
--- /dev/null
+++ b/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs
@@ -0,0 +1,20 @@
+#nullable enable
+using Bit.Core.Billing.Payment.Models;
+
+namespace Bit.Api.Billing.Models.Requests.Payment;
+
+public record BillingAddressRequest : CheckoutBillingAddressRequest
+{
+ public string? Line1 { get; set; }
+ public string? Line2 { get; set; }
+ public string? City { get; set; }
+ public string? State { get; set; }
+
+ public override BillingAddress ToDomain() => base.ToDomain() with
+ {
+ Line1 = Line1,
+ Line2 = Line2,
+ City = City,
+ State = State,
+ };
+}
diff --git a/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs b/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs
new file mode 100644
index 0000000000..bb6e7498d7
--- /dev/null
+++ b/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs
@@ -0,0 +1,13 @@
+#nullable enable
+using System.ComponentModel.DataAnnotations;
+
+namespace Bit.Api.Billing.Models.Requests.Payment;
+
+public record BitPayCreditRequest
+{
+ [Required]
+ public required decimal Amount { get; set; }
+
+ [Required]
+ public required string RedirectUrl { get; set; } = null!;
+}
diff --git a/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs
new file mode 100644
index 0000000000..54116e897d
--- /dev/null
+++ b/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs
@@ -0,0 +1,24 @@
+#nullable enable
+using System.ComponentModel.DataAnnotations;
+using Bit.Core.Billing.Payment.Models;
+
+namespace Bit.Api.Billing.Models.Requests.Payment;
+
+public record CheckoutBillingAddressRequest : MinimalBillingAddressRequest
+{
+ public TaxIdRequest? TaxId { get; set; }
+
+ public override BillingAddress ToDomain() => base.ToDomain() with
+ {
+ TaxId = TaxId != null ? new TaxID(TaxId.Code, TaxId.Value) : null
+ };
+
+ public class TaxIdRequest
+ {
+ [Required]
+ public string Code { get; set; } = null!;
+
+ [Required]
+ public string Value { get; set; } = null!;
+ }
+}
diff --git a/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs
new file mode 100644
index 0000000000..b4d28017d5
--- /dev/null
+++ b/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs
@@ -0,0 +1,16 @@
+#nullable enable
+using System.ComponentModel.DataAnnotations;
+using Bit.Core.Billing.Payment.Models;
+
+namespace Bit.Api.Billing.Models.Requests.Payment;
+
+public record MinimalBillingAddressRequest
+{
+ [Required]
+ [StringLength(2, MinimumLength = 2, ErrorMessage = "Country code must be 2 characters long.")]
+ public required string Country { get; set; } = null!;
+ [Required]
+ public required string PostalCode { get; set; } = null!;
+
+ public virtual BillingAddress ToDomain() => new() { Country = Country, PostalCode = PostalCode, };
+}
diff --git a/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs
new file mode 100644
index 0000000000..663e4e7cd2
--- /dev/null
+++ b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs
@@ -0,0 +1,39 @@
+#nullable enable
+using System.ComponentModel.DataAnnotations;
+using Bit.Api.Utilities;
+using Bit.Core.Billing.Payment.Models;
+
+namespace Bit.Api.Billing.Models.Requests.Payment;
+
+public class TokenizedPaymentMethodRequest
+{
+ [Required]
+ [StringMatches("bankAccount", "card", "payPal",
+ ErrorMessage = "Payment method type must be one of: bankAccount, card, payPal")]
+ public required string Type { get; set; }
+
+ [Required]
+ public required string Token { get; set; }
+
+ public MinimalBillingAddressRequest? BillingAddress { get; set; }
+
+ public (TokenizedPaymentMethod, BillingAddress?) ToDomain()
+ {
+ var paymentMethod = new TokenizedPaymentMethod
+ {
+ Type = Type switch
+ {
+ "bankAccount" => TokenizablePaymentMethodType.BankAccount,
+ "card" => TokenizablePaymentMethodType.Card,
+ "payPal" => TokenizablePaymentMethodType.PayPal,
+ _ => throw new InvalidOperationException(
+ $"Invalid value for {nameof(TokenizedPaymentMethod)}.{nameof(TokenizedPaymentMethod.Type)}")
+ },
+ Token = Token
+ };
+
+ var billingAddress = BillingAddress?.ToDomain();
+
+ return (paymentMethod, billingAddress);
+ }
+}
diff --git a/src/Api/Billing/Models/Requests/Payment/VerifyBankAccountRequest.cs b/src/Api/Billing/Models/Requests/Payment/VerifyBankAccountRequest.cs
new file mode 100644
index 0000000000..2b5d6a0cb1
--- /dev/null
+++ b/src/Api/Billing/Models/Requests/Payment/VerifyBankAccountRequest.cs
@@ -0,0 +1,9 @@
+using System.ComponentModel.DataAnnotations;
+
+namespace Bit.Api.Billing.Models.Requests.Payment;
+
+public class VerifyBankAccountRequest
+{
+ [Required]
+ public required string DescriptorCode { get; set; }
+}
diff --git a/src/Api/Billing/Models/Requirements/ManageOrganizationBillingRequirement.cs b/src/Api/Billing/Models/Requirements/ManageOrganizationBillingRequirement.cs
new file mode 100644
index 0000000000..4efdf0812a
--- /dev/null
+++ b/src/Api/Billing/Models/Requirements/ManageOrganizationBillingRequirement.cs
@@ -0,0 +1,18 @@
+#nullable enable
+using Bit.Api.AdminConsole.Authorization;
+using Bit.Core.Context;
+using Bit.Core.Enums;
+
+namespace Bit.Api.Billing.Models.Requirements;
+
+public class ManageOrganizationBillingRequirement : IOrganizationRequirement
+{
+ public async Task AuthorizeAsync(
+ CurrentContextOrganization? organizationClaims,
+ Func> isProviderUserForOrg)
+ => organizationClaims switch
+ {
+ { Type: OrganizationUserType.Owner } => true,
+ _ => await isProviderUserForOrg()
+ };
+}
diff --git a/src/Api/Utilities/StringMatchesAttribute.cs b/src/Api/Utilities/StringMatchesAttribute.cs
new file mode 100644
index 0000000000..28485aed40
--- /dev/null
+++ b/src/Api/Utilities/StringMatchesAttribute.cs
@@ -0,0 +1,18 @@
+using System.ComponentModel.DataAnnotations;
+
+namespace Bit.Api.Utilities;
+
+public class StringMatchesAttribute(params string[]? accepted) : ValidationAttribute
+{
+ public override bool IsValid(object? value)
+ {
+ if (value is not string str ||
+ accepted == null ||
+ accepted.Length == 0)
+ {
+ return false;
+ }
+
+ return accepted.Contains(str);
+ }
+}
diff --git a/src/Core/Billing/Commands/BillingCommand.cs b/src/Core/Billing/Commands/BillingCommand.cs
new file mode 100644
index 0000000000..e6c6375b62
--- /dev/null
+++ b/src/Core/Billing/Commands/BillingCommand.cs
@@ -0,0 +1,62 @@
+using Bit.Core.Billing.Constants;
+using Microsoft.Extensions.Logging;
+using Stripe;
+
+namespace Bit.Core.Billing.Commands;
+
+using static StripeConstants;
+
+public abstract class BillingCommand(
+ ILogger logger)
+{
+ protected string CommandName => GetType().Name;
+
+ ///
+ /// Executes the provided function within a predefined execution context, handling any exceptions that occur during the process.
+ ///
+ /// The type of the successful result expected from the provided function.
+ /// A function that performs an operation and returns a .
+ /// A task that represents the operation. The result provides a which may indicate success or an error outcome.
+ protected async Task> HandleAsync(
+ Func>> function)
+ {
+ try
+ {
+ return await function();
+ }
+ catch (StripeException stripeException) when (ErrorCodes.Get().Contains(stripeException.StripeError.Code))
+ {
+ return stripeException.StripeError.Code switch
+ {
+ ErrorCodes.CustomerTaxLocationInvalid =>
+ new BadRequest("Your location wasn't recognized. Please ensure your country and postal code are valid and try again."),
+
+ ErrorCodes.PaymentMethodMicroDepositVerificationAttemptsExceeded =>
+ new BadRequest("You have exceeded the number of allowed verification attempts. Please contact support for assistance."),
+
+ ErrorCodes.PaymentMethodMicroDepositVerificationDescriptorCodeMismatch =>
+ new BadRequest("The verification code you provided does not match the one sent to your bank account. Please try again."),
+
+ ErrorCodes.PaymentMethodMicroDepositVerificationTimeout =>
+ new BadRequest("Your bank account was not verified within the required time period. Please contact support for assistance."),
+
+ ErrorCodes.TaxIdInvalid =>
+ new BadRequest("The tax ID number you provided was invalid. Please try again or contact support for assistance."),
+
+ _ => new Unhandled(stripeException)
+ };
+ }
+ catch (StripeException stripeException)
+ {
+ logger.LogError(stripeException,
+ "{Command}: An error occurred while communicating with Stripe | Code = {Code}", CommandName,
+ stripeException.StripeError.Code);
+ return new Unhandled(stripeException);
+ }
+ catch (Exception exception)
+ {
+ logger.LogError(exception, "{Command}: An unknown error occurred during execution", CommandName);
+ return new Unhandled(exception);
+ }
+ }
+}
diff --git a/src/Core/Billing/Commands/BillingCommandResult.cs b/src/Core/Billing/Commands/BillingCommandResult.cs
new file mode 100644
index 0000000000..b69ad4bf12
--- /dev/null
+++ b/src/Core/Billing/Commands/BillingCommandResult.cs
@@ -0,0 +1,31 @@
+#nullable enable
+using OneOf;
+
+namespace Bit.Core.Billing.Commands;
+
+public record BadRequest(string Response);
+public record Conflict(string Response);
+public record Unhandled(Exception? Exception = null, string Response = "Something went wrong with your request. Please contact support for assistance.");
+
+///
+/// A union type representing the result of a billing command.
+///
+/// Choices include:
+///
+/// - : Success
+/// - : Invalid input
+/// - : A known, but unresolvable issue
+/// - : An unknown issue
+///
+///
+///
+/// The successful result type of the operation.
+public class BillingCommandResult : OneOfBase
+{
+ private BillingCommandResult(OneOf input) : base(input) { }
+
+ public static implicit operator BillingCommandResult(T output) => new(output);
+ public static implicit operator BillingCommandResult(BadRequest badRequest) => new(badRequest);
+ public static implicit operator BillingCommandResult(Conflict conflict) => new(conflict);
+ public static implicit operator BillingCommandResult(Unhandled unhandled) => new(unhandled);
+}
diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs
index 0cffad72d3..3aaa519d66 100644
--- a/src/Core/Billing/Constants/StripeConstants.cs
+++ b/src/Core/Billing/Constants/StripeConstants.cs
@@ -1,4 +1,6 @@
-namespace Bit.Core.Billing.Constants;
+using System.Reflection;
+
+namespace Bit.Core.Billing.Constants;
public static class StripeConstants
{
@@ -36,6 +38,13 @@ public static class StripeConstants
public const string PaymentMethodMicroDepositVerificationDescriptorCodeMismatch = "payment_method_microdeposit_verification_descriptor_code_mismatch";
public const string PaymentMethodMicroDepositVerificationTimeout = "payment_method_microdeposit_verification_timeout";
public const string TaxIdInvalid = "tax_id_invalid";
+
+ public static string[] Get() =>
+ typeof(ErrorCodes)
+ .GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy)
+ .Where(fi => fi is { IsLiteral: true, IsInitOnly: false } && fi.FieldType == typeof(string))
+ .Select(fi => (string)fi.GetValue(null)!)
+ .ToArray();
}
public static class InvoiceStatus
@@ -51,6 +60,7 @@ public static class StripeConstants
public const string InvoiceApproved = "invoice_approved";
public const string OrganizationId = "organizationId";
public const string ProviderId = "providerId";
+ public const string RetiredBraintreeCustomerId = "btCustomerId_old";
public const string UserId = "userId";
}
diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs
index 5c7a42e9b8..5f1a0668bd 100644
--- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs
+++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs
@@ -1,6 +1,7 @@
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Caches.Implementations;
using Bit.Core.Billing.Licenses.Extensions;
+using Bit.Core.Billing.Payment;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations;
@@ -27,5 +28,6 @@ public static class ServiceCollectionExtensions
services.AddLicenseServices();
services.AddPricingClient();
services.AddTransient();
+ services.AddPaymentOperations();
}
}
diff --git a/src/Core/Billing/Extensions/SubscriberExtensions.cs b/src/Core/Billing/Extensions/SubscriberExtensions.cs
index e322ed7317..fc804de224 100644
--- a/src/Core/Billing/Extensions/SubscriberExtensions.cs
+++ b/src/Core/Billing/Extensions/SubscriberExtensions.cs
@@ -1,4 +1,8 @@
-using Bit.Core.Entities;
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.Billing.Enums;
+using Bit.Core.Billing.Payment.Models;
+using Bit.Core.Entities;
namespace Bit.Core.Billing.Extensions;
@@ -23,4 +27,14 @@ public static class SubscriberExtensions
? subscriberName
: subscriberName[..30];
}
+
+ public static ProductUsageType GetProductUsageType(this ISubscriber subscriber)
+ => subscriber switch
+ {
+ User => ProductUsageType.Personal,
+ Organization organization when organization.PlanType.GetProductTier() is ProductTierType.Free or ProductTierType.Families => ProductUsageType.Personal,
+ Organization => ProductUsageType.Business,
+ Provider => ProductUsageType.Business,
+ _ => throw new ArgumentOutOfRangeException(nameof(subscriber))
+ };
}
diff --git a/src/Core/Billing/Models/BillingCommandResult.cs b/src/Core/Billing/Models/BillingCommandResult.cs
deleted file mode 100644
index 1b8eefe8df..0000000000
--- a/src/Core/Billing/Models/BillingCommandResult.cs
+++ /dev/null
@@ -1,36 +0,0 @@
-using OneOf;
-
-namespace Bit.Core.Billing.Models;
-
-public record BadRequest(string TranslationKey)
-{
- public static BadRequest TaxIdNumberInvalid => new(BillingErrorTranslationKeys.TaxIdInvalid);
- public static BadRequest TaxLocationInvalid => new(BillingErrorTranslationKeys.CustomerTaxLocationInvalid);
- public static BadRequest UnknownTaxIdType => new(BillingErrorTranslationKeys.UnknownTaxIdType);
-}
-
-public record Unhandled(string TranslationKey = BillingErrorTranslationKeys.UnhandledError);
-
-public class BillingCommandResult : OneOfBase
-{
- private BillingCommandResult(OneOf input) : base(input) { }
-
- public static implicit operator BillingCommandResult(T output) => new(output);
- public static implicit operator BillingCommandResult(BadRequest badRequest) => new(badRequest);
- public static implicit operator BillingCommandResult(Unhandled unhandled) => new(unhandled);
-}
-
-public static class BillingErrorTranslationKeys
-{
- // "The tax ID number you provided was invalid. Please try again or contact support."
- public const string TaxIdInvalid = "taxIdInvalid";
-
- // "Your location wasn't recognized. Please ensure your country and postal code are valid and try again."
- public const string CustomerTaxLocationInvalid = "customerTaxLocationInvalid";
-
- // "Something went wrong with your request. Please contact support."
- public const string UnhandledError = "unhandledBillingError";
-
- // "We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support."
- public const string UnknownTaxIdType = "unknownTaxIdType";
-}
diff --git a/src/Core/Billing/Payment/Clients/BitPayClient.cs b/src/Core/Billing/Payment/Clients/BitPayClient.cs
new file mode 100644
index 0000000000..2cb8fb66ef
--- /dev/null
+++ b/src/Core/Billing/Payment/Clients/BitPayClient.cs
@@ -0,0 +1,24 @@
+using Bit.Core.Settings;
+using BitPayLight;
+using BitPayLight.Models.Invoice;
+
+namespace Bit.Core.Billing.Payment.Clients;
+
+public interface IBitPayClient
+{
+ Task GetInvoice(string invoiceId);
+ Task CreateInvoice(Invoice invoice);
+}
+
+public class BitPayClient(
+ GlobalSettings globalSettings) : IBitPayClient
+{
+ private readonly BitPay _bitPay = new(
+ globalSettings.BitPay.Token, globalSettings.BitPay.Production ? Env.Prod : Env.Test);
+
+ public Task GetInvoice(string invoiceId)
+ => _bitPay.GetInvoice(invoiceId);
+
+ public Task CreateInvoice(Invoice invoice)
+ => _bitPay.CreateInvoice(invoice);
+}
diff --git a/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs
new file mode 100644
index 0000000000..f61fa9d0f9
--- /dev/null
+++ b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs
@@ -0,0 +1,59 @@
+#nullable enable
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.Billing.Commands;
+using Bit.Core.Billing.Payment.Clients;
+using Bit.Core.Entities;
+using Bit.Core.Settings;
+using BitPayLight.Models.Invoice;
+using Microsoft.Extensions.Logging;
+
+namespace Bit.Core.Billing.Payment.Commands;
+
+public interface ICreateBitPayInvoiceForCreditCommand
+{
+ Task> Run(
+ ISubscriber subscriber,
+ decimal amount,
+ string redirectUrl);
+}
+
+public class CreateBitPayInvoiceForCreditCommand(
+ IBitPayClient bitPayClient,
+ GlobalSettings globalSettings,
+ ILogger logger) : BillingCommand(logger), ICreateBitPayInvoiceForCreditCommand
+{
+ public Task> Run(
+ ISubscriber subscriber,
+ decimal amount,
+ string redirectUrl) => HandleAsync(async () =>
+ {
+ var (name, email, posData) = GetSubscriberInformation(subscriber);
+
+ var invoice = new Invoice
+ {
+ Buyer = new Buyer { Email = email, Name = name },
+ Currency = "USD",
+ ExtendedNotifications = true,
+ FullNotifications = true,
+ ItemDesc = "Bitwarden",
+ NotificationUrl = globalSettings.BitPay.NotificationUrl,
+ PosData = posData,
+ Price = Convert.ToDouble(amount),
+ RedirectUrl = redirectUrl
+ };
+
+ var created = await bitPayClient.CreateInvoice(invoice);
+ return created.Url;
+ });
+
+ private static (string? Name, string? Email, string POSData) GetSubscriberInformation(
+ ISubscriber subscriber) => subscriber switch
+ {
+ User user => (user.Email, user.Email, $"userId:{user.Id},accountCredit:1"),
+ Organization organization => (organization.Name, organization.BillingEmail,
+ $"organizationId:{organization.Id},accountCredit:1"),
+ Provider provider => (provider.Name, provider.BillingEmail, $"providerId:{provider.Id},accountCredit:1"),
+ _ => throw new ArgumentOutOfRangeException(nameof(subscriber))
+ };
+}
diff --git a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs
new file mode 100644
index 0000000000..adc534bd7d
--- /dev/null
+++ b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs
@@ -0,0 +1,129 @@
+#nullable enable
+using Bit.Core.Billing.Commands;
+using Bit.Core.Billing.Constants;
+using Bit.Core.Billing.Extensions;
+using Bit.Core.Billing.Payment.Models;
+using Bit.Core.Entities;
+using Bit.Core.Services;
+using Microsoft.Extensions.Logging;
+using Stripe;
+
+namespace Bit.Core.Billing.Payment.Commands;
+
+public interface IUpdateBillingAddressCommand
+{
+ Task> Run(
+ ISubscriber subscriber,
+ BillingAddress billingAddress);
+}
+
+public class UpdateBillingAddressCommand(
+ ILogger logger,
+ IStripeAdapter stripeAdapter) : BillingCommand(logger), IUpdateBillingAddressCommand
+{
+ public Task> Run(
+ ISubscriber subscriber,
+ BillingAddress billingAddress) => HandleAsync(() => subscriber.GetProductUsageType() switch
+ {
+ ProductUsageType.Personal => UpdatePersonalBillingAddressAsync(subscriber, billingAddress),
+ ProductUsageType.Business => UpdateBusinessBillingAddressAsync(subscriber, billingAddress)
+ });
+
+ private async Task> UpdatePersonalBillingAddressAsync(
+ ISubscriber subscriber,
+ BillingAddress billingAddress)
+ {
+ var customer =
+ await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId,
+ new CustomerUpdateOptions
+ {
+ Address = new AddressOptions
+ {
+ Country = billingAddress.Country,
+ PostalCode = billingAddress.PostalCode,
+ Line1 = billingAddress.Line1,
+ Line2 = billingAddress.Line2,
+ City = billingAddress.City,
+ State = billingAddress.State
+ },
+ Expand = ["subscriptions"]
+ });
+
+ await EnableAutomaticTaxAsync(subscriber, customer);
+
+ return BillingAddress.From(customer.Address);
+ }
+
+ private async Task> UpdateBusinessBillingAddressAsync(
+ ISubscriber subscriber,
+ BillingAddress billingAddress)
+ {
+ var customer =
+ await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId,
+ new CustomerUpdateOptions
+ {
+ Address = new AddressOptions
+ {
+ Country = billingAddress.Country,
+ PostalCode = billingAddress.PostalCode,
+ Line1 = billingAddress.Line1,
+ Line2 = billingAddress.Line2,
+ City = billingAddress.City,
+ State = billingAddress.State
+ },
+ Expand = ["subscriptions", "tax_ids"],
+ TaxExempt = billingAddress.Country != "US"
+ ? StripeConstants.TaxExempt.Reverse
+ : StripeConstants.TaxExempt.None
+ });
+
+ await EnableAutomaticTaxAsync(subscriber, customer);
+
+ var deleteExistingTaxIds = customer.TaxIds?.Any() ?? false
+ ? customer.TaxIds.Select(taxId => stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id)).ToList()
+ : [];
+
+ if (billingAddress.TaxId == null)
+ {
+ await Task.WhenAll(deleteExistingTaxIds);
+ return BillingAddress.From(customer.Address);
+ }
+
+ var updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id,
+ new TaxIdCreateOptions { Type = billingAddress.TaxId.Code, Value = billingAddress.TaxId.Value });
+
+ if (billingAddress.TaxId.Code == StripeConstants.TaxIdType.SpanishNIF)
+ {
+ updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id,
+ new TaxIdCreateOptions
+ {
+ Type = StripeConstants.TaxIdType.EUVAT,
+ Value = $"ES{billingAddress.TaxId.Value}"
+ });
+ }
+
+ await Task.WhenAll(deleteExistingTaxIds);
+
+ return BillingAddress.From(customer.Address, updatedTaxId);
+ }
+
+ private async Task EnableAutomaticTaxAsync(
+ ISubscriber subscriber,
+ Customer customer)
+ {
+ if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
+ {
+ var subscription = customer.Subscriptions.FirstOrDefault(subscription =>
+ subscription.Id == subscriber.GatewaySubscriptionId);
+
+ if (subscription is { AutomaticTax.Enabled: false })
+ {
+ await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId,
+ new SubscriptionUpdateOptions
+ {
+ AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
+ });
+ }
+ }
+ }
+}
diff --git a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs
new file mode 100644
index 0000000000..cda685d520
--- /dev/null
+++ b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs
@@ -0,0 +1,205 @@
+#nullable enable
+using Bit.Core.Billing.Caches;
+using Bit.Core.Billing.Commands;
+using Bit.Core.Billing.Constants;
+using Bit.Core.Billing.Payment.Models;
+using Bit.Core.Billing.Services;
+using Bit.Core.Entities;
+using Bit.Core.Services;
+using Bit.Core.Settings;
+using Bit.Core.Utilities;
+using Braintree;
+using Microsoft.Extensions.Logging;
+using Stripe;
+using Customer = Stripe.Customer;
+
+namespace Bit.Core.Billing.Payment.Commands;
+
+public interface IUpdatePaymentMethodCommand
+{
+ Task> Run(
+ ISubscriber subscriber,
+ TokenizedPaymentMethod paymentMethod,
+ BillingAddress? billingAddress);
+}
+
+public class UpdatePaymentMethodCommand(
+ IBraintreeGateway braintreeGateway,
+ IGlobalSettings globalSettings,
+ ILogger logger,
+ ISetupIntentCache setupIntentCache,
+ IStripeAdapter stripeAdapter,
+ ISubscriberService subscriberService) : BillingCommand(logger), IUpdatePaymentMethodCommand
+{
+ private readonly ILogger _logger = logger;
+ private static readonly Conflict _conflict = new("We had a problem updating your payment method. Please contact support for assistance.");
+
+ public Task> Run(
+ ISubscriber subscriber,
+ TokenizedPaymentMethod paymentMethod,
+ BillingAddress? billingAddress) => HandleAsync(async () =>
+ {
+ var customer = await subscriberService.GetCustomer(subscriber);
+
+ var result = paymentMethod.Type switch
+ {
+ TokenizablePaymentMethodType.BankAccount => await AddBankAccountAsync(subscriber, customer, paymentMethod.Token),
+ TokenizablePaymentMethodType.Card => await AddCardAsync(customer, paymentMethod.Token),
+ TokenizablePaymentMethodType.PayPal => await AddPayPalAsync(subscriber, customer, paymentMethod.Token),
+ _ => new BadRequest($"Payment method type '{paymentMethod.Type}' is not supported.")
+ };
+
+ if (billingAddress != null && customer.Address is not { Country: not null, PostalCode: not null })
+ {
+ await stripeAdapter.CustomerUpdateAsync(customer.Id,
+ new CustomerUpdateOptions
+ {
+ Address = new AddressOptions
+ {
+ Country = billingAddress.Country,
+ PostalCode = billingAddress.PostalCode
+ }
+ });
+ }
+
+ return result;
+ });
+
+ private async Task> AddBankAccountAsync(
+ ISubscriber subscriber,
+ Customer customer,
+ string token)
+ {
+ var setupIntents = await stripeAdapter.SetupIntentList(new SetupIntentListOptions
+ {
+ Expand = ["data.payment_method"],
+ PaymentMethod = token
+ });
+
+ switch (setupIntents.Count)
+ {
+ case 0:
+ _logger.LogError("{Command}: Could not find setup intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id);
+ return _conflict;
+ case > 1:
+ _logger.LogError("{Command}: Found more than one set up intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id);
+ return _conflict;
+ }
+
+ var setupIntent = setupIntents.First();
+
+ await setupIntentCache.Set(subscriber.Id, setupIntent.Id);
+
+ await UnlinkBraintreeCustomerAsync(customer);
+
+ return MaskedPaymentMethod.From(setupIntent);
+ }
+
+ private async Task> AddCardAsync(
+ Customer customer,
+ string token)
+ {
+ var paymentMethod = await stripeAdapter.PaymentMethodAttachAsync(token, new PaymentMethodAttachOptions { Customer = customer.Id });
+
+ await stripeAdapter.CustomerUpdateAsync(customer.Id,
+ new CustomerUpdateOptions
+ {
+ InvoiceSettings = new CustomerInvoiceSettingsOptions { DefaultPaymentMethod = token }
+ });
+
+ await UnlinkBraintreeCustomerAsync(customer);
+
+ return MaskedPaymentMethod.From(paymentMethod.Card);
+ }
+
+ private async Task> AddPayPalAsync(
+ ISubscriber subscriber,
+ Customer customer,
+ string token)
+ {
+ Braintree.Customer braintreeCustomer;
+
+ if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId))
+ {
+ braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId);
+
+ await ReplaceBraintreePaymentMethodAsync(braintreeCustomer, token);
+ }
+ else
+ {
+ braintreeCustomer = await CreateBraintreeCustomerAsync(subscriber, token);
+
+ var metadata = new Dictionary
+ {
+ [StripeConstants.MetadataKeys.BraintreeCustomerId] = braintreeCustomer.Id
+ };
+
+ await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata });
+ }
+
+ var payPalAccount = braintreeCustomer.DefaultPaymentMethod as PayPalAccount;
+
+ return MaskedPaymentMethod.From(payPalAccount!);
+ }
+
+ private async Task CreateBraintreeCustomerAsync(
+ ISubscriber subscriber,
+ string token)
+ {
+ var braintreeCustomerId =
+ subscriber.BraintreeCustomerIdPrefix() +
+ subscriber.Id.ToString("N").ToLower() +
+ CoreHelpers.RandomString(3, upper: false, numeric: false);
+
+ var result = await braintreeGateway.Customer.CreateAsync(new CustomerRequest
+ {
+ Id = braintreeCustomerId,
+ CustomFields = new Dictionary
+ {
+ [subscriber.BraintreeIdField()] = subscriber.Id.ToString(),
+ [subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion
+ },
+ Email = subscriber.BillingEmailAddress(),
+ PaymentMethodNonce = token
+ });
+
+ return result.Target;
+ }
+
+ private async Task ReplaceBraintreePaymentMethodAsync(
+ Braintree.Customer customer,
+ string token)
+ {
+ var existing = customer.DefaultPaymentMethod;
+
+ var result = await braintreeGateway.PaymentMethod.CreateAsync(new PaymentMethodRequest
+ {
+ CustomerId = customer.Id,
+ PaymentMethodNonce = token
+ });
+
+ await braintreeGateway.Customer.UpdateAsync(
+ customer.Id,
+ new CustomerRequest { DefaultPaymentMethodToken = result.Target.Token });
+
+ if (existing != null)
+ {
+ await braintreeGateway.PaymentMethod.DeleteAsync(existing.Token);
+ }
+ }
+
+ private async Task UnlinkBraintreeCustomerAsync(
+ Customer customer)
+ {
+ if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId))
+ {
+ var metadata = new Dictionary
+ {
+ [StripeConstants.MetadataKeys.RetiredBraintreeCustomerId] = braintreeCustomerId,
+ [StripeConstants.MetadataKeys.BraintreeCustomerId] = string.Empty
+ };
+
+ await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata });
+ }
+ }
+}
diff --git a/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs b/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs
new file mode 100644
index 0000000000..1e9492b876
--- /dev/null
+++ b/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs
@@ -0,0 +1,63 @@
+#nullable enable
+using Bit.Core.Billing.Caches;
+using Bit.Core.Billing.Commands;
+using Bit.Core.Billing.Payment.Models;
+using Bit.Core.Entities;
+using Bit.Core.Services;
+using Microsoft.Extensions.Logging;
+using Stripe;
+
+namespace Bit.Core.Billing.Payment.Commands;
+
+public interface IVerifyBankAccountCommand
+{
+ Task> Run(
+ ISubscriber subscriber,
+ string descriptorCode);
+}
+
+public class VerifyBankAccountCommand(
+ ILogger logger,
+ ISetupIntentCache setupIntentCache,
+ IStripeAdapter stripeAdapter) : BillingCommand(logger), IVerifyBankAccountCommand
+{
+ private readonly ILogger _logger = logger;
+
+ private static readonly Conflict _conflict =
+ new("We had a problem verifying your bank account. Please contact support for assistance.");
+
+ public Task> Run(
+ ISubscriber subscriber,
+ string descriptorCode) => HandleAsync(async () =>
+ {
+ var setupIntentId = await setupIntentCache.Get(subscriber.Id);
+
+ if (string.IsNullOrEmpty(setupIntentId))
+ {
+ _logger.LogError(
+ "{Command}: Could not find setup intent to verify subscriber's ({SubscriberID}) bank account",
+ CommandName, subscriber.Id);
+ return _conflict;
+ }
+
+ await stripeAdapter.SetupIntentVerifyMicroDeposit(setupIntentId,
+ new SetupIntentVerifyMicrodepositsOptions { DescriptorCode = descriptorCode });
+
+ var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId,
+ new SetupIntentGetOptions { Expand = ["payment_method"] });
+
+ var paymentMethod = await stripeAdapter.PaymentMethodAttachAsync(setupIntent.PaymentMethodId,
+ new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId });
+
+ await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId,
+ new CustomerUpdateOptions
+ {
+ InvoiceSettings = new CustomerInvoiceSettingsOptions
+ {
+ DefaultPaymentMethod = setupIntent.PaymentMethodId
+ }
+ });
+
+ return MaskedPaymentMethod.From(paymentMethod.UsBankAccount);
+ });
+}
diff --git a/src/Core/Billing/Payment/Models/BillingAddress.cs b/src/Core/Billing/Payment/Models/BillingAddress.cs
new file mode 100644
index 0000000000..5c2c43231c
--- /dev/null
+++ b/src/Core/Billing/Payment/Models/BillingAddress.cs
@@ -0,0 +1,30 @@
+#nullable enable
+using Stripe;
+
+namespace Bit.Core.Billing.Payment.Models;
+
+public record TaxID(string Code, string Value);
+
+public record BillingAddress
+{
+ public required string Country { get; set; }
+ public required string PostalCode { get; set; }
+ public string? Line1 { get; set; }
+ public string? Line2 { get; set; }
+ public string? City { get; set; }
+ public string? State { get; set; }
+ public TaxID? TaxId { get; set; }
+
+ public static BillingAddress From(Address address) => new()
+ {
+ Country = address.Country,
+ PostalCode = address.PostalCode,
+ Line1 = address.Line1,
+ Line2 = address.Line2,
+ City = address.City,
+ State = address.State
+ };
+
+ public static BillingAddress From(Address address, TaxId? taxId) =>
+ From(address) with { TaxId = taxId != null ? new TaxID(taxId.Type, taxId.Value) : null };
+}
diff --git a/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs b/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs
new file mode 100644
index 0000000000..c98fddc785
--- /dev/null
+++ b/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs
@@ -0,0 +1,120 @@
+#nullable enable
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using Bit.Core.Billing.Pricing.JSON;
+using Braintree;
+using OneOf;
+using Stripe;
+
+namespace Bit.Core.Billing.Payment.Models;
+
+public record MaskedBankAccount
+{
+ public required string BankName { get; init; }
+ public required string Last4 { get; init; }
+ public required bool Verified { get; init; }
+ public string Type => "bankAccount";
+}
+
+public record MaskedCard
+{
+ public required string Brand { get; init; }
+ public required string Last4 { get; init; }
+ public required string Expiration { get; init; }
+ public string Type => "card";
+}
+
+public record MaskedPayPalAccount
+{
+ public required string Email { get; init; }
+ public string Type => "payPal";
+}
+
+[JsonConverter(typeof(MaskedPaymentMethodJsonConverter))]
+public class MaskedPaymentMethod(OneOf input)
+ : OneOfBase(input)
+{
+ public static implicit operator MaskedPaymentMethod(MaskedBankAccount bankAccount) => new(bankAccount);
+ public static implicit operator MaskedPaymentMethod(MaskedCard card) => new(card);
+ public static implicit operator MaskedPaymentMethod(MaskedPayPalAccount payPal) => new(payPal);
+
+ public static MaskedPaymentMethod From(BankAccount bankAccount) => new MaskedBankAccount
+ {
+ BankName = bankAccount.BankName,
+ Last4 = bankAccount.Last4,
+ Verified = bankAccount.Status == "verified"
+ };
+
+ public static MaskedPaymentMethod From(Card card) => new MaskedCard
+ {
+ Brand = card.Brand.ToLower(),
+ Last4 = card.Last4,
+ Expiration = $"{card.ExpMonth:00}/{card.ExpYear}"
+ };
+
+ public static MaskedPaymentMethod From(PaymentMethodCard card) => new MaskedCard
+ {
+ Brand = card.Brand.ToLower(),
+ Last4 = card.Last4,
+ Expiration = $"{card.ExpMonth:00}/{card.ExpYear}"
+ };
+
+ public static MaskedPaymentMethod From(SetupIntent setupIntent) => new MaskedBankAccount
+ {
+ BankName = setupIntent.PaymentMethod.UsBankAccount.BankName,
+ Last4 = setupIntent.PaymentMethod.UsBankAccount.Last4,
+ Verified = false
+ };
+
+ public static MaskedPaymentMethod From(SourceCard sourceCard) => new MaskedCard
+ {
+ Brand = sourceCard.Brand.ToLower(),
+ Last4 = sourceCard.Last4,
+ Expiration = $"{sourceCard.ExpMonth:00}/{sourceCard.ExpYear}"
+ };
+
+ public static MaskedPaymentMethod From(PaymentMethodUsBankAccount bankAccount) => new MaskedBankAccount
+ {
+ BankName = bankAccount.BankName,
+ Last4 = bankAccount.Last4,
+ Verified = true
+ };
+
+ public static MaskedPaymentMethod From(PayPalAccount payPalAccount) => new MaskedPayPalAccount { Email = payPalAccount.Email };
+}
+
+public class MaskedPaymentMethodJsonConverter : TypeReadingJsonConverter
+{
+ protected override string TypePropertyName => nameof(MaskedBankAccount.Type).ToLower();
+
+ public override MaskedPaymentMethod? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ var type = ReadType(reader);
+
+ return type switch
+ {
+ "bankAccount" => JsonSerializer.Deserialize(ref reader, options) switch
+ {
+ null => null,
+ var bankAccount => new MaskedPaymentMethod(bankAccount)
+ },
+ "card" => JsonSerializer.Deserialize(ref reader, options) switch
+ {
+ null => null,
+ var card => new MaskedPaymentMethod(card)
+ },
+ "payPal" => JsonSerializer.Deserialize(ref reader, options) switch
+ {
+ null => null,
+ var payPal => new MaskedPaymentMethod(payPal)
+ },
+ _ => Skip(ref reader)
+ };
+ }
+
+ public override void Write(Utf8JsonWriter writer, MaskedPaymentMethod value, JsonSerializerOptions options)
+ => value.Switch(
+ bankAccount => JsonSerializer.Serialize(writer, bankAccount, options),
+ card => JsonSerializer.Serialize(writer, card, options),
+ payPal => JsonSerializer.Serialize(writer, payPal, options));
+}
diff --git a/src/Core/Billing/Payment/Models/ProductUsageType.cs b/src/Core/Billing/Payment/Models/ProductUsageType.cs
new file mode 100644
index 0000000000..2ecd1233c6
--- /dev/null
+++ b/src/Core/Billing/Payment/Models/ProductUsageType.cs
@@ -0,0 +1,7 @@
+namespace Bit.Core.Billing.Payment.Models;
+
+public enum ProductUsageType
+{
+ Personal,
+ Business
+}
diff --git a/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs b/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs
new file mode 100644
index 0000000000..d27a924360
--- /dev/null
+++ b/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs
@@ -0,0 +1,8 @@
+namespace Bit.Core.Billing.Payment.Models;
+
+public enum TokenizablePaymentMethodType
+{
+ BankAccount,
+ Card,
+ PayPal
+}
diff --git a/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs b/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs
new file mode 100644
index 0000000000..edbf1bb121
--- /dev/null
+++ b/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs
@@ -0,0 +1,8 @@
+#nullable enable
+namespace Bit.Core.Billing.Payment.Models;
+
+public record TokenizedPaymentMethod
+{
+ public required TokenizablePaymentMethodType Type { get; set; }
+ public required string Token { get; set; }
+}
diff --git a/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs b/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs
new file mode 100644
index 0000000000..84d4d4f377
--- /dev/null
+++ b/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs
@@ -0,0 +1,41 @@
+#nullable enable
+using Bit.Core.Billing.Extensions;
+using Bit.Core.Billing.Payment.Models;
+using Bit.Core.Billing.Services;
+using Bit.Core.Entities;
+using Stripe;
+
+namespace Bit.Core.Billing.Payment.Queries;
+
+public interface IGetBillingAddressQuery
+{
+ Task Run(ISubscriber subscriber);
+}
+
+public class GetBillingAddressQuery(
+ ISubscriberService subscriberService) : IGetBillingAddressQuery
+{
+ public async Task Run(ISubscriber subscriber)
+ {
+ var productUsageType = subscriber.GetProductUsageType();
+
+ var options = productUsageType switch
+ {
+ ProductUsageType.Business => new CustomerGetOptions { Expand = ["tax_ids"] },
+ _ => new CustomerGetOptions()
+ };
+
+ var customer = await subscriberService.GetCustomer(subscriber, options);
+
+ if (customer is not { Address: { Country: not null, PostalCode: not null } })
+ {
+ return null;
+ }
+
+ var taxId = productUsageType == ProductUsageType.Business ? customer.TaxIds?.FirstOrDefault() : null;
+
+ return taxId != null
+ ? BillingAddress.From(customer.Address, taxId)
+ : BillingAddress.From(customer.Address);
+ }
+}
diff --git a/src/Core/Billing/Payment/Queries/GetCreditQuery.cs b/src/Core/Billing/Payment/Queries/GetCreditQuery.cs
new file mode 100644
index 0000000000..79c9a13aba
--- /dev/null
+++ b/src/Core/Billing/Payment/Queries/GetCreditQuery.cs
@@ -0,0 +1,26 @@
+#nullable enable
+using Bit.Core.Billing.Services;
+using Bit.Core.Entities;
+
+namespace Bit.Core.Billing.Payment.Queries;
+
+public interface IGetCreditQuery
+{
+ Task Run(ISubscriber subscriber);
+}
+
+public class GetCreditQuery(
+ ISubscriberService subscriberService) : IGetCreditQuery
+{
+ public async Task Run(ISubscriber subscriber)
+ {
+ var customer = await subscriberService.GetCustomer(subscriber);
+
+ if (customer == null)
+ {
+ return null;
+ }
+
+ return Convert.ToDecimal(customer.Balance) * -1 / 100;
+ }
+}
diff --git a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs
new file mode 100644
index 0000000000..eb42a8c78a
--- /dev/null
+++ b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs
@@ -0,0 +1,96 @@
+#nullable enable
+using Bit.Core.Billing.Caches;
+using Bit.Core.Billing.Constants;
+using Bit.Core.Billing.Extensions;
+using Bit.Core.Billing.Payment.Models;
+using Bit.Core.Billing.Services;
+using Bit.Core.Entities;
+using Bit.Core.Services;
+using Braintree;
+using Microsoft.Extensions.Logging;
+using Stripe;
+
+namespace Bit.Core.Billing.Payment.Queries;
+
+public interface IGetPaymentMethodQuery
+{
+ Task Run(ISubscriber subscriber);
+}
+
+public class GetPaymentMethodQuery(
+ IBraintreeGateway braintreeGateway,
+ ILogger logger,
+ ISetupIntentCache setupIntentCache,
+ IStripeAdapter stripeAdapter,
+ ISubscriberService subscriberService) : IGetPaymentMethodQuery
+{
+ public async Task Run(ISubscriber subscriber)
+ {
+ var customer = await subscriberService.GetCustomer(subscriber,
+ new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] });
+
+ if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId))
+ {
+ var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId);
+
+ if (braintreeCustomer.DefaultPaymentMethod is PayPalAccount payPalAccount)
+ {
+ return new MaskedPayPalAccount { Email = payPalAccount.Email };
+ }
+
+ logger.LogWarning("Subscriber ({SubscriberID}) has a linked Braintree customer ({BraintreeCustomerId}) with no PayPal account.", subscriber.Id, braintreeCustomerId);
+
+ return null;
+ }
+
+ var paymentMethod = customer.InvoiceSettings.DefaultPaymentMethod != null
+ ? customer.InvoiceSettings.DefaultPaymentMethod.Type switch
+ {
+ "card" => MaskedPaymentMethod.From(customer.InvoiceSettings.DefaultPaymentMethod.Card),
+ "us_bank_account" => MaskedPaymentMethod.From(customer.InvoiceSettings.DefaultPaymentMethod.UsBankAccount),
+ _ => null
+ }
+ : null;
+
+ if (paymentMethod != null)
+ {
+ return paymentMethod;
+ }
+
+ if (customer.DefaultSource != null)
+ {
+ paymentMethod = customer.DefaultSource switch
+ {
+ Card card => MaskedPaymentMethod.From(card),
+ BankAccount bankAccount => MaskedPaymentMethod.From(bankAccount),
+ Source { Card: not null } source => MaskedPaymentMethod.From(source.Card),
+ _ => null
+ };
+
+ if (paymentMethod != null)
+ {
+ return paymentMethod;
+ }
+ }
+
+ var setupIntentId = await setupIntentCache.Get(subscriber.Id);
+
+ if (string.IsNullOrEmpty(setupIntentId))
+ {
+ return null;
+ }
+
+ var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions
+ {
+ Expand = ["payment_method"]
+ });
+
+ // ReSharper disable once ConvertIfStatementToReturnStatement
+ if (!setupIntent.IsUnverifiedBankAccount())
+ {
+ return null;
+ }
+
+ return MaskedPaymentMethod.From(setupIntent);
+ }
+}
diff --git a/src/Core/Billing/Payment/Registrations.cs b/src/Core/Billing/Payment/Registrations.cs
new file mode 100644
index 0000000000..1cc7914f10
--- /dev/null
+++ b/src/Core/Billing/Payment/Registrations.cs
@@ -0,0 +1,24 @@
+using Bit.Core.Billing.Payment.Clients;
+using Bit.Core.Billing.Payment.Commands;
+using Bit.Core.Billing.Payment.Queries;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Bit.Core.Billing.Payment;
+
+public static class Registrations
+{
+ public static void AddPaymentOperations(this IServiceCollection services)
+ {
+ // Commands
+ services.AddTransient();
+ services.AddTransient();
+ services.AddTransient();
+ services.AddTransient();
+ services.AddTransient();
+
+ // Queries
+ services.AddTransient();
+ services.AddTransient();
+ services.AddTransient();
+ }
+}
diff --git a/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs b/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs
index ef8d33304e..05beccdb60 100644
--- a/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs
+++ b/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs
@@ -6,7 +6,7 @@ namespace Bit.Core.Billing.Pricing.JSON;
#nullable enable
-public abstract class TypeReadingJsonConverter : JsonConverter
+public abstract class TypeReadingJsonConverter : JsonConverter where T : class
{
protected virtual string TypePropertyName => nameof(ScalableDTO.Type).ToLower();
@@ -14,7 +14,9 @@ public abstract class TypeReadingJsonConverter : JsonConverter
{
while (reader.Read())
{
- if (reader.TokenType != JsonTokenType.PropertyName || reader.GetString()?.ToLower() != TypePropertyName)
+ if (reader.CurrentDepth != 1 ||
+ reader.TokenType != JsonTokenType.PropertyName ||
+ reader.GetString()?.ToLower() != TypePropertyName)
{
continue;
}
@@ -25,4 +27,10 @@ public abstract class TypeReadingJsonConverter : JsonConverter
return null;
}
+
+ protected T? Skip(ref Utf8JsonReader reader)
+ {
+ reader.Skip();
+ return null;
+ }
}
diff --git a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs
index c777d0c0d1..86f233232f 100644
--- a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs
+++ b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs
@@ -1,8 +1,8 @@
#nullable enable
+using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
-using Bit.Core.Billing.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Services;
@@ -20,111 +20,95 @@ public class PreviewTaxAmountCommand(
ILogger logger,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter,
- ITaxService taxService) : IPreviewTaxAmountCommand
+ ITaxService taxService) : BillingCommand(logger), IPreviewTaxAmountCommand
{
- public async Task> Run(OrganizationTrialParameters parameters)
- {
- var (planType, productType, taxInformation) = parameters;
-
- var plan = await pricingClient.GetPlanOrThrow(planType);
-
- var options = new InvoiceCreatePreviewOptions
+ public Task> Run(OrganizationTrialParameters parameters)
+ => HandleAsync(async () =>
{
- Currency = "usd",
- CustomerDetails = new InvoiceCustomerDetailsOptions
+ var (planType, productType, taxInformation) = parameters;
+
+ var plan = await pricingClient.GetPlanOrThrow(planType);
+
+ var options = new InvoiceCreatePreviewOptions
{
- Address = new AddressOptions
+ Currency = "usd",
+ CustomerDetails = new InvoiceCustomerDetailsOptions
{
- Country = taxInformation.Country,
- PostalCode = taxInformation.PostalCode
- }
- },
- SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
- {
- Items = [
- new InvoiceSubscriptionDetailsItemOptions
+ Address = new AddressOptions
{
- Price = plan.HasNonSeatBasedPasswordManagerPlan() ? plan.PasswordManager.StripePlanId : plan.PasswordManager.StripeSeatPlanId,
- Quantity = 1
+ Country = taxInformation.Country,
+ PostalCode = taxInformation.PostalCode
}
- ]
- }
- };
-
- if (productType == ProductType.SecretsManager)
- {
- options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions
- {
- Price = plan.SecretsManager.StripeSeatPlanId,
- Quantity = 1
- });
-
- options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone;
- }
-
- if (!string.IsNullOrEmpty(taxInformation.TaxId))
- {
- var taxIdType = taxService.GetStripeTaxCode(
- taxInformation.Country,
- taxInformation.TaxId);
-
- if (string.IsNullOrEmpty(taxIdType))
- {
- return BadRequest.UnknownTaxIdType;
- }
-
- options.CustomerDetails.TaxIds = [
- new InvoiceCustomerDetailsTaxIdOptions
+ },
+ SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
- Type = taxIdType,
- Value = taxInformation.TaxId
+ Items =
+ [
+ new InvoiceSubscriptionDetailsItemOptions
+ {
+ Price = plan.HasNonSeatBasedPasswordManagerPlan()
+ ? plan.PasswordManager.StripePlanId
+ : plan.PasswordManager.StripeSeatPlanId,
+ Quantity = 1
+ }
+ ]
}
- ];
-
- if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
- {
- options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions
- {
- Type = StripeConstants.TaxIdType.EUVAT,
- Value = $"ES{parameters.TaxInformation.TaxId}"
- });
- }
- }
-
- if (planType.GetProductTier() == ProductTierType.Families)
- {
- options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true };
- }
- else
- {
- options.AutomaticTax = new InvoiceAutomaticTaxOptions
- {
- Enabled = options.CustomerDetails.Address.Country == "US" ||
- options.CustomerDetails.TaxIds is [_, ..]
};
- }
- try
- {
+ if (productType == ProductType.SecretsManager)
+ {
+ options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions
+ {
+ Price = plan.SecretsManager.StripeSeatPlanId,
+ Quantity = 1
+ });
+
+ options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone;
+ }
+
+ if (!string.IsNullOrEmpty(taxInformation.TaxId))
+ {
+ var taxIdType = taxService.GetStripeTaxCode(
+ taxInformation.Country,
+ taxInformation.TaxId);
+
+ if (string.IsNullOrEmpty(taxIdType))
+ {
+ return new BadRequest(
+ "We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support for assistance.");
+ }
+
+ options.CustomerDetails.TaxIds =
+ [
+ new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = taxInformation.TaxId }
+ ];
+
+ if (taxIdType == StripeConstants.TaxIdType.SpanishNIF)
+ {
+ options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions
+ {
+ Type = StripeConstants.TaxIdType.EUVAT,
+ Value = $"ES{parameters.TaxInformation.TaxId}"
+ });
+ }
+ }
+
+ if (planType.GetProductTier() == ProductTierType.Families)
+ {
+ options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true };
+ }
+ else
+ {
+ options.AutomaticTax = new InvoiceAutomaticTaxOptions
+ {
+ Enabled = options.CustomerDetails.Address.Country == "US" ||
+ options.CustomerDetails.TaxIds is [_, ..]
+ };
+ }
+
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return Convert.ToDecimal(invoice.Tax) / 100;
- }
- catch (StripeException stripeException) when (stripeException.StripeError.Code ==
- StripeConstants.ErrorCodes.CustomerTaxLocationInvalid)
- {
- return BadRequest.TaxLocationInvalid;
- }
- catch (StripeException stripeException) when (stripeException.StripeError.Code ==
- StripeConstants.ErrorCodes.TaxIdInvalid)
- {
- return BadRequest.TaxIdNumberInvalid;
- }
- catch (StripeException stripeException)
- {
- logger.LogError(stripeException, "Stripe responded with an error during {Operation}. Code: {Code}", nameof(PreviewTaxAmountCommand), stripeException.StripeError.Code);
- return new Unhandled();
- }
- }
+ });
}
#region Command Parameters
diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs
index 2a3b619de6..5c0bd3216e 100644
--- a/src/Core/Constants.cs
+++ b/src/Core/Constants.cs
@@ -155,6 +155,7 @@ public static class FeatureFlagKeys
public const string PM20322_AllowTrialLength0 = "pm-20322-allow-trial-length-0";
public const string PM21092_SetNonUSBusinessUseToReverseCharge = "pm-21092-set-non-us-business-use-to-reverse-charge";
public const string PM21383_GetProviderPriceFromStripe = "pm-21383-get-provider-price-from-stripe";
+ public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout";
/* Data Insights and Reporting Team */
public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application";
diff --git a/test/Api.Test/Billing/Attributes/InjectOrganizationAttributeTests.cs b/test/Api.Test/Billing/Attributes/InjectOrganizationAttributeTests.cs
new file mode 100644
index 0000000000..252c457924
--- /dev/null
+++ b/test/Api.Test/Billing/Attributes/InjectOrganizationAttributeTests.cs
@@ -0,0 +1,132 @@
+using Bit.Api.Billing.Attributes;
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.Models.Api;
+using Bit.Core.Repositories;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.Abstractions;
+using Microsoft.AspNetCore.Mvc.Filters;
+using Microsoft.AspNetCore.Mvc.ModelBinding;
+using Microsoft.AspNetCore.Routing;
+using Microsoft.Extensions.DependencyInjection;
+using NSubstitute;
+using Xunit;
+
+namespace Bit.Api.Test.Billing.Attributes;
+
+public class InjectOrganizationAttributeTests
+{
+ private readonly IOrganizationRepository _organizationRepository;
+ private readonly ActionExecutionDelegate _next;
+ private readonly ActionExecutingContext _context;
+ private readonly Organization _organization;
+ private readonly Guid _organizationId;
+
+ public InjectOrganizationAttributeTests()
+ {
+ _organizationRepository = Substitute.For();
+ _organizationId = Guid.NewGuid();
+ _organization = new Organization { Id = _organizationId };
+
+ var httpContext = new DefaultHttpContext();
+ var services = new ServiceCollection();
+ services.AddScoped(_ => _organizationRepository);
+ httpContext.RequestServices = services.BuildServiceProvider();
+
+ var routeData = new RouteData { Values = { ["organizationId"] = _organizationId.ToString() } };
+
+ var actionContext = new ActionContext(
+ httpContext,
+ routeData,
+ new ActionDescriptor(),
+ new ModelStateDictionary()
+ );
+
+ _next = () => Task.FromResult(new ActionExecutedContext(
+ actionContext,
+ new List(),
+ new object()));
+
+ _context = new ActionExecutingContext(
+ actionContext,
+ new List(),
+ new Dictionary(),
+ new object());
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithExistingOrganization_InjectsOrganization()
+ {
+ var attribute = new InjectOrganizationAttribute();
+ _organizationRepository.GetByIdAsync(_organizationId)
+ .Returns(_organization);
+
+ var parameter = new ParameterDescriptor
+ {
+ Name = "organization",
+ ParameterType = typeof(Organization)
+ };
+ _context.ActionDescriptor.Parameters = [parameter];
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.Equal(_organization, _context.ActionArguments["organization"]);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithNonExistentOrganization_ReturnsNotFound()
+ {
+ var attribute = new InjectOrganizationAttribute();
+ _organizationRepository.GetByIdAsync(_organizationId)
+ .Returns((Organization)null);
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.IsType(_context.Result);
+ var result = (NotFoundObjectResult)_context.Result;
+ Assert.IsType(result.Value);
+ Assert.Equal("Organization not found.", ((ErrorResponseModel)result.Value).Message);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithInvalidOrganizationId_ReturnsBadRequest()
+ {
+ var attribute = new InjectOrganizationAttribute();
+ _context.RouteData.Values["organizationId"] = "not-a-guid";
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.IsType(_context.Result);
+ var result = (BadRequestObjectResult)_context.Result;
+ Assert.IsType(result.Value);
+ Assert.Equal("Route parameter 'organizationId' is missing or invalid.", ((ErrorResponseModel)result.Value).Message);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithMissingOrganizationId_ReturnsBadRequest()
+ {
+ var attribute = new InjectOrganizationAttribute();
+ _context.RouteData.Values.Clear();
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.IsType(_context.Result);
+ var result = (BadRequestObjectResult)_context.Result;
+ Assert.IsType(result.Value);
+ Assert.Equal("Route parameter 'organizationId' is missing or invalid.", ((ErrorResponseModel)result.Value).Message);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithoutOrganizationParameter_ContinuesExecution()
+ {
+ var attribute = new InjectOrganizationAttribute();
+ _organizationRepository.GetByIdAsync(_organizationId)
+ .Returns(_organization);
+
+ _context.ActionDescriptor.Parameters = Array.Empty();
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.Empty(_context.ActionArguments);
+ }
+}
diff --git a/test/Api.Test/Billing/Attributes/InjectProviderAttributeTests.cs b/test/Api.Test/Billing/Attributes/InjectProviderAttributeTests.cs
new file mode 100644
index 0000000000..0a3e19f8b1
--- /dev/null
+++ b/test/Api.Test/Billing/Attributes/InjectProviderAttributeTests.cs
@@ -0,0 +1,190 @@
+using Bit.Api.Billing.Attributes;
+using Bit.Api.Models.Public.Response;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.AdminConsole.Enums.Provider;
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.Context;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.Abstractions;
+using Microsoft.AspNetCore.Mvc.Filters;
+using Microsoft.AspNetCore.Mvc.ModelBinding;
+using Microsoft.AspNetCore.Routing;
+using Microsoft.Extensions.DependencyInjection;
+using NSubstitute;
+using Xunit;
+
+namespace Bit.Api.Test.Billing.Attributes;
+
+public class InjectProviderAttributeTests
+{
+ private readonly IProviderRepository _providerRepository;
+ private readonly ICurrentContext _currentContext;
+ private readonly ActionExecutionDelegate _next;
+ private readonly ActionExecutingContext _context;
+ private readonly Provider _provider;
+ private readonly Guid _providerId;
+
+ public InjectProviderAttributeTests()
+ {
+ _providerRepository = Substitute.For();
+ _currentContext = Substitute.For();
+ _providerId = Guid.NewGuid();
+ _provider = new Provider { Id = _providerId };
+
+ var httpContext = new DefaultHttpContext();
+ var services = new ServiceCollection();
+ services.AddScoped(_ => _providerRepository);
+ services.AddScoped(_ => _currentContext);
+ httpContext.RequestServices = services.BuildServiceProvider();
+
+ var routeData = new RouteData { Values = { ["providerId"] = _providerId.ToString() } };
+
+ var actionContext = new ActionContext(
+ httpContext,
+ routeData,
+ new ActionDescriptor(),
+ new ModelStateDictionary()
+ );
+
+ _next = () => Task.FromResult(new ActionExecutedContext(
+ actionContext,
+ new List(),
+ new object()));
+
+ _context = new ActionExecutingContext(
+ actionContext,
+ new List(),
+ new Dictionary(),
+ new object());
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithExistingProvider_InjectsProvider()
+ {
+ var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin);
+ _providerRepository.GetByIdAsync(_providerId).Returns(_provider);
+ _currentContext.ProviderProviderAdmin(_providerId).Returns(true);
+
+ var parameter = new ParameterDescriptor
+ {
+ Name = "provider",
+ ParameterType = typeof(Provider)
+ };
+ _context.ActionDescriptor.Parameters = [parameter];
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.Equal(_provider, _context.ActionArguments["provider"]);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithNonExistentProvider_ReturnsNotFound()
+ {
+ var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin);
+ _providerRepository.GetByIdAsync(_providerId).Returns((Provider)null);
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.IsType(_context.Result);
+ var result = (NotFoundObjectResult)_context.Result;
+ Assert.IsType(result.Value);
+ Assert.Equal("Provider not found.", ((ErrorResponseModel)result.Value).Message);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithInvalidProviderId_ReturnsBadRequest()
+ {
+ var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin);
+ _context.RouteData.Values["providerId"] = "not-a-guid";
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.IsType(_context.Result);
+ var result = (BadRequestObjectResult)_context.Result;
+ Assert.IsType(result.Value);
+ Assert.Equal("Route parameter 'providerId' is missing or invalid.", ((ErrorResponseModel)result.Value).Message);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithMissingProviderId_ReturnsBadRequest()
+ {
+ var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin);
+ _context.RouteData.Values.Clear();
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.IsType(_context.Result);
+ var result = (BadRequestObjectResult)_context.Result;
+ Assert.IsType(result.Value);
+ Assert.Equal("Route parameter 'providerId' is missing or invalid.", ((ErrorResponseModel)result.Value).Message);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithoutProviderParameter_ContinuesExecution()
+ {
+ var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin);
+ _providerRepository.GetByIdAsync(_providerId).Returns(_provider);
+ _currentContext.ProviderProviderAdmin(_providerId).Returns(true);
+
+ _context.ActionDescriptor.Parameters = Array.Empty();
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.Empty(_context.ActionArguments);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_UnauthorizedProviderAdmin_ReturnsUnauthorized()
+ {
+ var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin);
+ _providerRepository.GetByIdAsync(_providerId).Returns(_provider);
+ _currentContext.ProviderProviderAdmin(_providerId).Returns(false);
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.IsType(_context.Result);
+ var result = (UnauthorizedObjectResult)_context.Result;
+ Assert.IsType(result.Value);
+ Assert.Equal("Unauthorized.", ((ErrorResponseModel)result.Value).Message);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_UnauthorizedServiceUser_ReturnsUnauthorized()
+ {
+ var attribute = new InjectProviderAttribute(ProviderUserType.ServiceUser);
+ _providerRepository.GetByIdAsync(_providerId).Returns(_provider);
+ _currentContext.ProviderUser(_providerId).Returns(false);
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.IsType(_context.Result);
+ var result = (UnauthorizedObjectResult)_context.Result;
+ Assert.IsType(result.Value);
+ Assert.Equal("Unauthorized.", ((ErrorResponseModel)result.Value).Message);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_AuthorizedProviderAdmin_Succeeds()
+ {
+ var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin);
+ _providerRepository.GetByIdAsync(_providerId).Returns(_provider);
+ _currentContext.ProviderProviderAdmin(_providerId).Returns(true);
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.Null(_context.Result);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_AuthorizedServiceUser_Succeeds()
+ {
+ var attribute = new InjectProviderAttribute(ProviderUserType.ServiceUser);
+ _providerRepository.GetByIdAsync(_providerId).Returns(_provider);
+ _currentContext.ProviderUser(_providerId).Returns(true);
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.Null(_context.Result);
+ }
+}
diff --git a/test/Api.Test/Billing/Attributes/InjectUserAttributesTests.cs b/test/Api.Test/Billing/Attributes/InjectUserAttributesTests.cs
new file mode 100644
index 0000000000..5c26cca64a
--- /dev/null
+++ b/test/Api.Test/Billing/Attributes/InjectUserAttributesTests.cs
@@ -0,0 +1,129 @@
+using System.Security.Claims;
+using Bit.Api.Billing.Attributes;
+using Bit.Core.Entities;
+using Bit.Core.Models.Api;
+using Bit.Core.Services;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.Abstractions;
+using Microsoft.AspNetCore.Mvc.Filters;
+using Microsoft.AspNetCore.Mvc.ModelBinding;
+using Microsoft.AspNetCore.Routing;
+using Microsoft.Extensions.DependencyInjection;
+using NSubstitute;
+using Xunit;
+
+namespace Bit.Api.Test.Billing.Attributes;
+
+public class InjectUserAttributesTests
+{
+ private readonly IUserService _userService;
+ private readonly ActionExecutionDelegate _next;
+ private readonly ActionExecutingContext _context;
+ private readonly User _user;
+
+ public InjectUserAttributesTests()
+ {
+ _userService = Substitute.For();
+ _user = new User { Id = Guid.NewGuid() };
+
+ var httpContext = new DefaultHttpContext();
+ var services = new ServiceCollection();
+ services.AddScoped(_ => _userService);
+ httpContext.RequestServices = services.BuildServiceProvider();
+
+ var actionContext = new ActionContext(
+ httpContext,
+ new RouteData(),
+ new ActionDescriptor(),
+ new ModelStateDictionary()
+ );
+
+ _next = () => Task.FromResult(new ActionExecutedContext(
+ actionContext,
+ new List(),
+ new object()));
+
+ _context = new ActionExecutingContext(
+ actionContext,
+ new List(),
+ new Dictionary(),
+ new object());
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithAuthorizedUser_InjectsUser()
+ {
+ var attribute = new InjectUserAttribute();
+ _userService.GetUserByPrincipalAsync(Arg.Any())
+ .Returns(_user);
+
+ var parameter = new ParameterDescriptor
+ {
+ Name = "user",
+ ParameterType = typeof(User)
+ };
+ _context.ActionDescriptor.Parameters = [parameter];
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.Equal(_user, _context.ActionArguments["user"]);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithUnauthorizedUser_ReturnsUnauthorized()
+ {
+ var attribute = new InjectUserAttribute();
+ _userService.GetUserByPrincipalAsync(Arg.Any())
+ .Returns((User)null);
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.IsType(_context.Result);
+ var result = (UnauthorizedObjectResult)_context.Result;
+ Assert.IsType(result.Value);
+ Assert.Equal("Unauthorized.", ((ErrorResponseModel)result.Value).Message);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithoutUserParameter_ContinuesExecution()
+ {
+ var attribute = new InjectUserAttribute();
+ _userService.GetUserByPrincipalAsync(Arg.Any())
+ .Returns(_user);
+
+ _context.ActionDescriptor.Parameters = Array.Empty();
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.Empty(_context.ActionArguments);
+ }
+
+ [Fact]
+ public async Task OnActionExecutionAsync_WithMultipleParameters_InjectsUserCorrectly()
+ {
+ var attribute = new InjectUserAttribute();
+ _userService.GetUserByPrincipalAsync(Arg.Any())
+ .Returns(_user);
+
+ var parameters = new[]
+ {
+ new ParameterDescriptor
+ {
+ Name = "otherParam",
+ ParameterType = typeof(string)
+ },
+ new ParameterDescriptor
+ {
+ Name = "user",
+ ParameterType = typeof(User)
+ }
+ };
+ _context.ActionDescriptor.Parameters = parameters;
+
+ await attribute.OnActionExecutionAsync(_context, _next);
+
+ Assert.Single(_context.ActionArguments);
+ Assert.Equal(_user, _context.ActionArguments["user"]);
+ }
+}
diff --git a/test/Core.Test/Billing/Extensions/StripeExtensions.cs b/test/Core.Test/Billing/Extensions/StripeExtensions.cs
new file mode 100644
index 0000000000..44948bbfed
--- /dev/null
+++ b/test/Core.Test/Billing/Extensions/StripeExtensions.cs
@@ -0,0 +1,18 @@
+using Bit.Core.Billing.Payment.Models;
+using Stripe;
+
+namespace Bit.Core.Test.Billing.Extensions;
+
+public static class StripeExtensions
+{
+ public static bool HasExpansions(this BaseOptions options, params string[] expansions)
+ => expansions.All(expansion => options.Expand.Contains(expansion));
+
+ public static bool Matches(this AddressOptions address, BillingAddress billingAddress) =>
+ address.Country == billingAddress.Country &&
+ address.PostalCode == billingAddress.PostalCode &&
+ address.Line1 == billingAddress.Line1 &&
+ address.Line2 == billingAddress.Line2 &&
+ address.City == billingAddress.City &&
+ address.State == billingAddress.State;
+}
diff --git a/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs
new file mode 100644
index 0000000000..800c3ec3ae
--- /dev/null
+++ b/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs
@@ -0,0 +1,94 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.Billing.Payment.Clients;
+using Bit.Core.Billing.Payment.Commands;
+using Bit.Core.Entities;
+using Bit.Core.Settings;
+using Microsoft.Extensions.Logging;
+using NSubstitute;
+using Xunit;
+using Invoice = BitPayLight.Models.Invoice.Invoice;
+
+namespace Bit.Core.Test.Billing.Payment.Commands;
+
+public class CreateBitPayInvoiceForCreditCommandTests
+{
+ private readonly IBitPayClient _bitPayClient = Substitute.For();
+ private readonly GlobalSettings _globalSettings = new()
+ {
+ BitPay = new GlobalSettings.BitPaySettings { NotificationUrl = "https://example.com/bitpay/notification" }
+ };
+ private const string _redirectUrl = "https://bitwarden.com/redirect";
+ private readonly CreateBitPayInvoiceForCreditCommand _command;
+
+ public CreateBitPayInvoiceForCreditCommandTests()
+ {
+ _command = new CreateBitPayInvoiceForCreditCommand(
+ _bitPayClient,
+ _globalSettings,
+ Substitute.For>());
+ }
+
+ [Fact]
+ public async Task Run_User_CreatesInvoice_ReturnsInvoiceUrl()
+ {
+ var user = new User { Id = Guid.NewGuid(), Email = "user@gmail.com" };
+
+ _bitPayClient.CreateInvoice(Arg.Is(options =>
+ options.Buyer.Email == user.Email &&
+ options.Buyer.Name == user.Email &&
+ options.NotificationUrl == _globalSettings.BitPay.NotificationUrl &&
+ options.PosData == $"userId:{user.Id},accountCredit:1" &&
+ // ReSharper disable once CompareOfFloatsByEqualityOperator
+ options.Price == Convert.ToDouble(10M) &&
+ options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" });
+
+ var result = await _command.Run(user, 10M, _redirectUrl);
+
+ Assert.True(result.IsT0);
+ var invoiceUrl = result.AsT0;
+ Assert.Equal("https://bitpay.com/invoice/123", invoiceUrl);
+ }
+
+ [Fact]
+ public async Task Run_Organization_CreatesInvoice_ReturnsInvoiceUrl()
+ {
+ var organization = new Organization { Id = Guid.NewGuid(), BillingEmail = "organization@example.com", Name = "Organization" };
+
+ _bitPayClient.CreateInvoice(Arg.Is(options =>
+ options.Buyer.Email == organization.BillingEmail &&
+ options.Buyer.Name == organization.Name &&
+ options.NotificationUrl == _globalSettings.BitPay.NotificationUrl &&
+ options.PosData == $"organizationId:{organization.Id},accountCredit:1" &&
+ // ReSharper disable once CompareOfFloatsByEqualityOperator
+ options.Price == Convert.ToDouble(10M) &&
+ options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" });
+
+ var result = await _command.Run(organization, 10M, _redirectUrl);
+
+ Assert.True(result.IsT0);
+ var invoiceUrl = result.AsT0;
+ Assert.Equal("https://bitpay.com/invoice/123", invoiceUrl);
+ }
+
+ [Fact]
+ public async Task Run_Provider_CreatesInvoice_ReturnsInvoiceUrl()
+ {
+ var provider = new Provider { Id = Guid.NewGuid(), BillingEmail = "organization@example.com", Name = "Provider" };
+
+ _bitPayClient.CreateInvoice(Arg.Is(options =>
+ options.Buyer.Email == provider.BillingEmail &&
+ options.Buyer.Name == provider.Name &&
+ options.NotificationUrl == _globalSettings.BitPay.NotificationUrl &&
+ options.PosData == $"providerId:{provider.Id},accountCredit:1" &&
+ // ReSharper disable once CompareOfFloatsByEqualityOperator
+ options.Price == Convert.ToDouble(10M) &&
+ options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" });
+
+ var result = await _command.Run(provider, 10M, _redirectUrl);
+
+ Assert.True(result.IsT0);
+ var invoiceUrl = result.AsT0;
+ Assert.Equal("https://bitpay.com/invoice/123", invoiceUrl);
+ }
+}
diff --git a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs
new file mode 100644
index 0000000000..453d0c78e9
--- /dev/null
+++ b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs
@@ -0,0 +1,349 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.Billing.Constants;
+using Bit.Core.Billing.Enums;
+using Bit.Core.Billing.Payment.Commands;
+using Bit.Core.Billing.Payment.Models;
+using Bit.Core.Services;
+using Bit.Core.Test.Billing.Extensions;
+using Microsoft.Extensions.Logging;
+using NSubstitute;
+using Stripe;
+using Xunit;
+
+namespace Bit.Core.Test.Billing.Payment.Commands;
+
+using static StripeConstants;
+
+public class UpdateBillingAddressCommandTests
+{
+ private readonly IStripeAdapter _stripeAdapter;
+ private readonly UpdateBillingAddressCommand _command;
+
+ public UpdateBillingAddressCommandTests()
+ {
+ _stripeAdapter = Substitute.For();
+ _command = new UpdateBillingAddressCommand(
+ Substitute.For>(),
+ _stripeAdapter);
+ }
+
+ [Fact]
+ public async Task Run_PersonalOrganization_MakesCorrectInvocations_ReturnsBillingAddress()
+ {
+ var organization = new Organization
+ {
+ PlanType = PlanType.FamiliesAnnually,
+ GatewayCustomerId = "cus_123",
+ GatewaySubscriptionId = "sub_123"
+ };
+
+ var input = new BillingAddress
+ {
+ Country = "US",
+ PostalCode = "12345",
+ Line1 = "123 Main St.",
+ Line2 = "Suite 100",
+ City = "New York",
+ State = "NY"
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "US",
+ PostalCode = "12345",
+ Line1 = "123 Main St.",
+ Line2 = "Suite 100",
+ City = "New York",
+ State = "NY"
+ },
+ Subscriptions = new StripeList
+ {
+ Data =
+ [
+ new Subscription
+ {
+ Id = organization.GatewaySubscriptionId,
+ AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
+ }
+ ]
+ }
+ };
+
+ _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options =>
+ options.Address.Matches(input) &&
+ options.HasExpansions("subscriptions")
+ )).Returns(customer);
+
+ var result = await _command.Run(organization, input);
+
+ Assert.True(result.IsT0);
+ var output = result.AsT0;
+ Assert.Equivalent(input, output);
+
+ await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
+ Arg.Is(options => options.AutomaticTax.Enabled == true));
+ }
+
+ [Fact]
+ public async Task Run_BusinessOrganization_MakesCorrectInvocations_ReturnsBillingAddress()
+ {
+ var organization = new Organization
+ {
+ PlanType = PlanType.EnterpriseAnnually,
+ GatewayCustomerId = "cus_123",
+ GatewaySubscriptionId = "sub_123"
+ };
+
+ var input = new BillingAddress
+ {
+ Country = "US",
+ PostalCode = "12345",
+ Line1 = "123 Main St.",
+ Line2 = "Suite 100",
+ City = "New York",
+ State = "NY"
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "US",
+ PostalCode = "12345",
+ Line1 = "123 Main St.",
+ Line2 = "Suite 100",
+ City = "New York",
+ State = "NY"
+ },
+ Subscriptions = new StripeList
+ {
+ Data =
+ [
+ new Subscription
+ {
+ Id = organization.GatewaySubscriptionId,
+ AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
+ }
+ ]
+ }
+ };
+
+ _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options =>
+ options.Address.Matches(input) &&
+ options.HasExpansions("subscriptions", "tax_ids") &&
+ options.TaxExempt == TaxExempt.None
+ )).Returns(customer);
+
+ var result = await _command.Run(organization, input);
+
+ Assert.True(result.IsT0);
+ var output = result.AsT0;
+ Assert.Equivalent(input, output);
+
+ await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
+ Arg.Is(options => options.AutomaticTax.Enabled == true));
+ }
+
+ [Fact]
+ public async Task Run_BusinessOrganization_RemovingTaxId_MakesCorrectInvocations_ReturnsBillingAddress()
+ {
+ var organization = new Organization
+ {
+ PlanType = PlanType.EnterpriseAnnually,
+ GatewayCustomerId = "cus_123",
+ GatewaySubscriptionId = "sub_123"
+ };
+
+ var input = new BillingAddress
+ {
+ Country = "US",
+ PostalCode = "12345",
+ Line1 = "123 Main St.",
+ Line2 = "Suite 100",
+ City = "New York",
+ State = "NY"
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "US",
+ PostalCode = "12345",
+ Line1 = "123 Main St.",
+ Line2 = "Suite 100",
+ City = "New York",
+ State = "NY"
+ },
+ Id = organization.GatewayCustomerId,
+ Subscriptions = new StripeList
+ {
+ Data =
+ [
+ new Subscription
+ {
+ Id = organization.GatewaySubscriptionId,
+ AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
+ }
+ ]
+ },
+ TaxIds = new StripeList
+ {
+ Data =
+ [
+ new TaxId { Id = "tax_id_123", Type = "us_ein", Value = "123456789" }
+ ]
+ }
+ };
+
+ _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options =>
+ options.Address.Matches(input) &&
+ options.HasExpansions("subscriptions", "tax_ids") &&
+ options.TaxExempt == TaxExempt.None
+ )).Returns(customer);
+
+ var result = await _command.Run(organization, input);
+
+ Assert.True(result.IsT0);
+ var output = result.AsT0;
+ Assert.Equivalent(input, output);
+
+ await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
+ Arg.Is(options => options.AutomaticTax.Enabled == true));
+
+ await _stripeAdapter.Received(1).TaxIdDeleteAsync(customer.Id, "tax_id_123");
+ }
+
+ [Fact]
+ public async Task Run_NonUSBusinessOrganization_MakesCorrectInvocations_ReturnsBillingAddress()
+ {
+ var organization = new Organization
+ {
+ PlanType = PlanType.EnterpriseAnnually,
+ GatewayCustomerId = "cus_123",
+ GatewaySubscriptionId = "sub_123"
+ };
+
+ var input = new BillingAddress
+ {
+ Country = "DE",
+ PostalCode = "10115",
+ Line1 = "Friedrichstraße 123",
+ Line2 = "Stock 3",
+ City = "Berlin",
+ State = "Berlin"
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "DE",
+ PostalCode = "10115",
+ Line1 = "Friedrichstraße 123",
+ Line2 = "Stock 3",
+ City = "Berlin",
+ State = "Berlin"
+ },
+ Subscriptions = new StripeList
+ {
+ Data =
+ [
+ new Subscription
+ {
+ Id = organization.GatewaySubscriptionId,
+ AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
+ }
+ ]
+ }
+ };
+
+ _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options =>
+ options.Address.Matches(input) &&
+ options.HasExpansions("subscriptions", "tax_ids") &&
+ options.TaxExempt == TaxExempt.Reverse
+ )).Returns(customer);
+
+ var result = await _command.Run(organization, input);
+
+ Assert.True(result.IsT0);
+ var output = result.AsT0;
+ Assert.Equivalent(input, output);
+
+ await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
+ Arg.Is(options => options.AutomaticTax.Enabled == true));
+ }
+
+ [Fact]
+ public async Task Run_BusinessOrganizationWithSpanishCIF_MakesCorrectInvocations_ReturnsBillingAddress()
+ {
+ var organization = new Organization
+ {
+ PlanType = PlanType.EnterpriseAnnually,
+ GatewayCustomerId = "cus_123",
+ GatewaySubscriptionId = "sub_123"
+ };
+
+ var input = new BillingAddress
+ {
+ Country = "ES",
+ PostalCode = "28001",
+ Line1 = "Calle de Serrano 41",
+ Line2 = "Planta 3",
+ City = "Madrid",
+ State = "Madrid",
+ TaxId = new TaxID(TaxIdType.SpanishNIF, "A12345678")
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "ES",
+ PostalCode = "28001",
+ Line1 = "Calle de Serrano 41",
+ Line2 = "Planta 3",
+ City = "Madrid",
+ State = "Madrid"
+ },
+ Id = organization.GatewayCustomerId,
+ Subscriptions = new StripeList
+ {
+ Data =
+ [
+ new Subscription
+ {
+ Id = organization.GatewaySubscriptionId,
+ AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
+ }
+ ]
+ }
+ };
+
+ _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options =>
+ options.Address.Matches(input) &&
+ options.HasExpansions("subscriptions", "tax_ids") &&
+ options.TaxExempt == TaxExempt.Reverse
+ )).Returns(customer);
+
+ _stripeAdapter
+ .TaxIdCreateAsync(customer.Id,
+ Arg.Is(options => options.Type == TaxIdType.EUVAT))
+ .Returns(new TaxId { Type = TaxIdType.EUVAT, Value = "ESA12345678" });
+
+ var result = await _command.Run(organization, input);
+
+ Assert.True(result.IsT0);
+ var output = result.AsT0;
+ Assert.Equivalent(input with { TaxId = new TaxID(TaxIdType.EUVAT, "ESA12345678") }, output);
+
+ await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
+ Arg.Is(options => options.AutomaticTax.Enabled == true));
+
+ await _stripeAdapter.Received(1).TaxIdCreateAsync(organization.GatewayCustomerId, Arg.Is(
+ options => options.Type == TaxIdType.SpanishNIF &&
+ options.Value == input.TaxId.Value));
+ }
+}
diff --git a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs
new file mode 100644
index 0000000000..e7bc5c787c
--- /dev/null
+++ b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs
@@ -0,0 +1,399 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.Billing.Caches;
+using Bit.Core.Billing.Constants;
+using Bit.Core.Billing.Payment.Commands;
+using Bit.Core.Billing.Payment.Models;
+using Bit.Core.Billing.Services;
+using Bit.Core.Services;
+using Bit.Core.Settings;
+using Bit.Core.Test.Billing.Extensions;
+using Braintree;
+using Microsoft.Extensions.Logging;
+using NSubstitute;
+using Stripe;
+using Xunit;
+using Address = Stripe.Address;
+using Customer = Stripe.Customer;
+using PaymentMethod = Stripe.PaymentMethod;
+
+namespace Bit.Core.Test.Billing.Payment.Commands;
+
+using static StripeConstants;
+
+public class UpdatePaymentMethodCommandTests
+{
+ private readonly IBraintreeGateway _braintreeGateway = Substitute.For();
+ private readonly IGlobalSettings _globalSettings = Substitute.For();
+ private readonly ISetupIntentCache _setupIntentCache = Substitute.For();
+ private readonly IStripeAdapter _stripeAdapter = Substitute.For();
+ private readonly ISubscriberService _subscriberService = Substitute.For();
+ private readonly UpdatePaymentMethodCommand _command;
+
+ public UpdatePaymentMethodCommandTests()
+ {
+ _command = new UpdatePaymentMethodCommand(
+ _braintreeGateway,
+ _globalSettings,
+ Substitute.For>(),
+ _setupIntentCache,
+ _stripeAdapter,
+ _subscriberService);
+ }
+
+ [Fact]
+ public async Task Run_BankAccount_MakesCorrectInvocations_ReturnsMaskedBankAccount()
+ {
+ var organization = new Organization
+ {
+ Id = Guid.NewGuid()
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "US",
+ PostalCode = "12345"
+ },
+ Metadata = new Dictionary()
+ };
+
+ _subscriberService.GetCustomer(organization).Returns(customer);
+
+ const string token = "TOKEN";
+
+ var setupIntent = new SetupIntent
+ {
+ Id = "seti_123",
+ PaymentMethod =
+ new PaymentMethod
+ {
+ Type = "us_bank_account",
+ UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" }
+ },
+ NextAction = new SetupIntentNextAction
+ {
+ VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits()
+ },
+ Status = "requires_action"
+ };
+
+ _stripeAdapter.SetupIntentList(Arg.Is(options =>
+ options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]);
+
+ var result = await _command.Run(organization,
+ new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = token }, new BillingAddress
+ {
+ Country = "US",
+ PostalCode = "12345"
+ });
+
+ Assert.True(result.IsT0);
+ var maskedPaymentMethod = result.AsT0;
+ Assert.True(maskedPaymentMethod.IsT0);
+ var maskedBankAccount = maskedPaymentMethod.AsT0;
+ Assert.Equal("Chase", maskedBankAccount.BankName);
+ Assert.Equal("9999", maskedBankAccount.Last4);
+ Assert.False(maskedBankAccount.Verified);
+
+ await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id);
+ }
+
+ [Fact]
+ public async Task Run_BankAccount_StripeToPayPal_MakesCorrectInvocations_ReturnsMaskedBankAccount()
+ {
+ var organization = new Organization
+ {
+ Id = Guid.NewGuid()
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "US",
+ PostalCode = "12345"
+ },
+ Id = "cus_123",
+ Metadata = new Dictionary
+ {
+ [MetadataKeys.BraintreeCustomerId] = "braintree_customer_id"
+ }
+ };
+
+ _subscriberService.GetCustomer(organization).Returns(customer);
+
+ const string token = "TOKEN";
+
+ var setupIntent = new SetupIntent
+ {
+ Id = "seti_123",
+ PaymentMethod =
+ new PaymentMethod
+ {
+ Type = "us_bank_account",
+ UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" }
+ },
+ NextAction = new SetupIntentNextAction
+ {
+ VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits()
+ },
+ Status = "requires_action"
+ };
+
+ _stripeAdapter.SetupIntentList(Arg.Is(options =>
+ options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]);
+
+ var result = await _command.Run(organization,
+ new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = token }, new BillingAddress
+ {
+ Country = "US",
+ PostalCode = "12345"
+ });
+
+ Assert.True(result.IsT0);
+ var maskedPaymentMethod = result.AsT0;
+ Assert.True(maskedPaymentMethod.IsT0);
+ var maskedBankAccount = maskedPaymentMethod.AsT0;
+ Assert.Equal("Chase", maskedBankAccount.BankName);
+ Assert.Equal("9999", maskedBankAccount.Last4);
+ Assert.False(maskedBankAccount.Verified);
+
+ await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id);
+ await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, Arg.Is(options =>
+ options.Metadata[MetadataKeys.BraintreeCustomerId] == string.Empty &&
+ options.Metadata[MetadataKeys.RetiredBraintreeCustomerId] == "braintree_customer_id"));
+ }
+
+ [Fact]
+ public async Task Run_Card_MakesCorrectInvocations_ReturnsMaskedCard()
+ {
+ var organization = new Organization
+ {
+ Id = Guid.NewGuid()
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "US",
+ PostalCode = "12345"
+ },
+ Id = "cus_123",
+ Metadata = new Dictionary()
+ };
+
+ _subscriberService.GetCustomer(organization).Returns(customer);
+
+ const string token = "TOKEN";
+
+ _stripeAdapter
+ .PaymentMethodAttachAsync(token,
+ Arg.Is(options => options.Customer == customer.Id))
+ .Returns(new PaymentMethod
+ {
+ Type = "card",
+ Card = new PaymentMethodCard
+ {
+ Brand = "visa",
+ Last4 = "9999",
+ ExpMonth = 1,
+ ExpYear = 2028
+ }
+ });
+
+ var result = await _command.Run(organization,
+ new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = token }, new BillingAddress
+ {
+ Country = "US",
+ PostalCode = "12345"
+ });
+
+ Assert.True(result.IsT0);
+ var maskedPaymentMethod = result.AsT0;
+ Assert.True(maskedPaymentMethod.IsT1);
+ var maskedCard = maskedPaymentMethod.AsT1;
+ Assert.Equal("visa", maskedCard.Brand);
+ Assert.Equal("9999", maskedCard.Last4);
+ Assert.Equal("01/2028", maskedCard.Expiration);
+
+ await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id,
+ Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token));
+ }
+
+ [Fact]
+ public async Task Run_Card_PropagateBillingAddress_MakesCorrectInvocations_ReturnsMaskedCard()
+ {
+ var organization = new Organization
+ {
+ Id = Guid.NewGuid()
+ };
+
+ var customer = new Customer
+ {
+ Id = "cus_123",
+ Metadata = new Dictionary()
+ };
+
+ _subscriberService.GetCustomer(organization).Returns(customer);
+
+ const string token = "TOKEN";
+
+ _stripeAdapter
+ .PaymentMethodAttachAsync(token,
+ Arg.Is(options => options.Customer == customer.Id))
+ .Returns(new PaymentMethod
+ {
+ Type = "card",
+ Card = new PaymentMethodCard
+ {
+ Brand = "visa",
+ Last4 = "9999",
+ ExpMonth = 1,
+ ExpYear = 2028
+ }
+ });
+
+ var result = await _command.Run(organization,
+ new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = token }, new BillingAddress
+ {
+ Country = "US",
+ PostalCode = "12345"
+ });
+
+ Assert.True(result.IsT0);
+ var maskedPaymentMethod = result.AsT0;
+ Assert.True(maskedPaymentMethod.IsT1);
+ var maskedCard = maskedPaymentMethod.AsT1;
+ Assert.Equal("visa", maskedCard.Brand);
+ Assert.Equal("9999", maskedCard.Last4);
+ Assert.Equal("01/2028", maskedCard.Expiration);
+
+ await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id,
+ Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token));
+
+ await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id,
+ Arg.Is(options => options.Address.Country == "US" && options.Address.PostalCode == "12345"));
+ }
+
+ [Fact]
+ public async Task Run_PayPal_ExistingBraintreeCustomer_MakesCorrectInvocations_ReturnsMaskedPayPalAccount()
+ {
+ var organization = new Organization
+ {
+ Id = Guid.NewGuid()
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "US",
+ PostalCode = "12345"
+ },
+ Id = "cus_123",
+ Metadata = new Dictionary
+ {
+ [MetadataKeys.BraintreeCustomerId] = "braintree_customer_id"
+ }
+ };
+
+ _subscriberService.GetCustomer(organization).Returns(customer);
+
+ var customerGateway = Substitute.For();
+ var braintreeCustomer = Substitute.For();
+ braintreeCustomer.Id.Returns("braintree_customer_id");
+ var existing = Substitute.For();
+ existing.Email.Returns("user@gmail.com");
+ existing.IsDefault.Returns(true);
+ existing.Token.Returns("EXISTING");
+ braintreeCustomer.PaymentMethods.Returns([existing]);
+ customerGateway.FindAsync("braintree_customer_id").Returns(braintreeCustomer);
+ _braintreeGateway.Customer.Returns(customerGateway);
+
+ var paymentMethodGateway = Substitute.For();
+ var updated = Substitute.For();
+ updated.Email.Returns("user@gmail.com");
+ updated.Token.Returns("UPDATED");
+ var updatedResult = Substitute.For>();
+ updatedResult.Target.Returns(updated);
+ paymentMethodGateway.CreateAsync(Arg.Is(options =>
+ options.CustomerId == braintreeCustomer.Id && options.PaymentMethodNonce == "TOKEN"))
+ .Returns(updatedResult);
+ _braintreeGateway.PaymentMethod.Returns(paymentMethodGateway);
+
+ var result = await _command.Run(organization,
+ new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "TOKEN" },
+ new BillingAddress { Country = "US", PostalCode = "12345" });
+
+ Assert.True(result.IsT0);
+ var maskedPaymentMethod = result.AsT0;
+ Assert.True(maskedPaymentMethod.IsT2);
+ var maskedPayPalAccount = maskedPaymentMethod.AsT2;
+ Assert.Equal("user@gmail.com", maskedPayPalAccount.Email);
+
+ await customerGateway.Received(1).UpdateAsync(braintreeCustomer.Id,
+ Arg.Is(options => options.DefaultPaymentMethodToken == updated.Token));
+ await paymentMethodGateway.Received(1).DeleteAsync(existing.Token);
+ }
+
+ [Fact]
+ public async Task Run_PayPal_NewBraintreeCustomer_MakesCorrectInvocations_ReturnsMaskedPayPalAccount()
+ {
+ var organization = new Organization
+ {
+ Id = Guid.NewGuid()
+ };
+
+ var customer = new Customer
+ {
+ Address = new Address
+ {
+ Country = "US",
+ PostalCode = "12345"
+ },
+ Id = "cus_123",
+ Metadata = new Dictionary()
+ };
+
+ _subscriberService.GetCustomer(organization).Returns(customer);
+
+ _globalSettings.BaseServiceUri.Returns(new GlobalSettings.BaseServiceUriSettings(new GlobalSettings())
+ {
+ CloudRegion = "US"
+ });
+
+ var customerGateway = Substitute.For();
+ var braintreeCustomer = Substitute.For();
+ braintreeCustomer.Id.Returns("braintree_customer_id");
+ var payPalAccount = Substitute.For();
+ payPalAccount.Email.Returns("user@gmail.com");
+ payPalAccount.IsDefault.Returns(true);
+ payPalAccount.Token.Returns("NONCE");
+ braintreeCustomer.PaymentMethods.Returns([payPalAccount]);
+ var createResult = Substitute.For>();
+ createResult.Target.Returns(braintreeCustomer);
+ customerGateway.CreateAsync(Arg.Is(options =>
+ options.Id.StartsWith(organization.BraintreeCustomerIdPrefix() + organization.Id.ToString("N").ToLower()) &&
+ options.CustomFields[organization.BraintreeIdField()] == organization.Id.ToString() &&
+ options.CustomFields[organization.BraintreeCloudRegionField()] == "US" &&
+ options.Email == organization.BillingEmailAddress() &&
+ options.PaymentMethodNonce == "TOKEN")).Returns(createResult);
+ _braintreeGateway.Customer.Returns(customerGateway);
+
+ var result = await _command.Run(organization,
+ new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "TOKEN" },
+ new BillingAddress { Country = "US", PostalCode = "12345" });
+
+ Assert.True(result.IsT0);
+ var maskedPaymentMethod = result.AsT0;
+ Assert.True(maskedPaymentMethod.IsT2);
+ var maskedPayPalAccount = maskedPaymentMethod.AsT2;
+ Assert.Equal("user@gmail.com", maskedPayPalAccount.Email);
+
+ await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id,
+ Arg.Is(options =>
+ options.Metadata[MetadataKeys.BraintreeCustomerId] == "braintree_customer_id"));
+ }
+}
diff --git a/test/Core.Test/Billing/Payment/Commands/VerifyBankAccountCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/VerifyBankAccountCommandTests.cs
new file mode 100644
index 0000000000..4be5539cc8
--- /dev/null
+++ b/test/Core.Test/Billing/Payment/Commands/VerifyBankAccountCommandTests.cs
@@ -0,0 +1,81 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.Billing.Caches;
+using Bit.Core.Billing.Payment.Commands;
+using Bit.Core.Services;
+using Bit.Core.Test.Billing.Extensions;
+using Microsoft.Extensions.Logging;
+using NSubstitute;
+using Stripe;
+using Xunit;
+
+namespace Bit.Core.Test.Billing.Payment.Commands;
+
+public class VerifyBankAccountCommandTests
+{
+ private readonly ISetupIntentCache _setupIntentCache = Substitute.For();
+ private readonly IStripeAdapter _stripeAdapter = Substitute.For();
+ private readonly VerifyBankAccountCommand _command;
+
+ public VerifyBankAccountCommandTests()
+ {
+ _command = new VerifyBankAccountCommand(
+ Substitute.For>(),
+ _setupIntentCache,
+ _stripeAdapter);
+ }
+
+ [Fact]
+ public async Task Run_MakesCorrectInvocations_ReturnsMaskedBankAccount()
+ {
+ var organization = new Organization
+ {
+ Id = Guid.NewGuid(),
+ GatewayCustomerId = "cus_123"
+ };
+
+ const string setupIntentId = "seti_123";
+
+ _setupIntentCache.Get(organization.Id).Returns(setupIntentId);
+
+ var setupIntent = new SetupIntent
+ {
+ Id = setupIntentId,
+ PaymentMethodId = "pm_123",
+ PaymentMethod =
+ new PaymentMethod
+ {
+ Id = "pm_123",
+ Type = "us_bank_account",
+ UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" }
+ },
+ NextAction = new SetupIntentNextAction
+ {
+ VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits()
+ },
+ Status = "requires_action"
+ };
+
+ _stripeAdapter.SetupIntentGet(setupIntentId,
+ Arg.Is(options => options.HasExpansions("payment_method"))).Returns(setupIntent);
+
+ _stripeAdapter.PaymentMethodAttachAsync(setupIntent.PaymentMethodId,
+ Arg.Is(options => options.Customer == organization.GatewayCustomerId))
+ .Returns(setupIntent.PaymentMethod);
+
+ var result = await _command.Run(organization, "DESCRIPTOR_CODE");
+
+ Assert.True(result.IsT0);
+ var maskedPaymentMethod = result.AsT0;
+ Assert.True(maskedPaymentMethod.IsT0);
+ var maskedBankAccount = maskedPaymentMethod.AsT0;
+ Assert.Equal("Chase", maskedBankAccount.BankName);
+ Assert.Equal("9999", maskedBankAccount.Last4);
+ Assert.True(maskedBankAccount.Verified);
+
+ await _stripeAdapter.Received(1).SetupIntentVerifyMicroDeposit(setupIntent.Id,
+ Arg.Is(options => options.DescriptorCode == "DESCRIPTOR_CODE"));
+
+ await _stripeAdapter.Received(1).CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(
+ options => options.InvoiceSettings.DefaultPaymentMethod == setupIntent.PaymentMethodId));
+ }
+}
diff --git a/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs b/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs
new file mode 100644
index 0000000000..345f2dfab8
--- /dev/null
+++ b/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs
@@ -0,0 +1,63 @@
+using System.Text.Json;
+using Bit.Core.Billing.Payment.Models;
+using Xunit;
+
+namespace Bit.Core.Test.Billing.Payment.Models;
+
+public class MaskedPaymentMethodTests
+{
+ [Fact]
+ public void Write_Read_BankAccount_Succeeds()
+ {
+ MaskedPaymentMethod input = new MaskedBankAccount
+ {
+ BankName = "Chase",
+ Last4 = "9999",
+ Verified = true
+ };
+
+ var json = JsonSerializer.Serialize(input);
+
+ var output = JsonSerializer.Deserialize(json);
+ Assert.NotNull(output);
+ Assert.True(output.IsT0);
+
+ Assert.Equivalent(input.AsT0, output.AsT0);
+ }
+
+ [Fact]
+ public void Write_Read_Card_Succeeds()
+ {
+ MaskedPaymentMethod input = new MaskedCard
+ {
+ Brand = "visa",
+ Last4 = "9999",
+ Expiration = "01/2028"
+ };
+
+ var json = JsonSerializer.Serialize(input);
+
+ var output = JsonSerializer.Deserialize(json);
+ Assert.NotNull(output);
+ Assert.True(output.IsT1);
+
+ Assert.Equivalent(input.AsT1, output.AsT1);
+ }
+
+ [Fact]
+ public void Write_Read_PayPal_Succeeds()
+ {
+ MaskedPaymentMethod input = new MaskedPayPalAccount
+ {
+ Email = "paypal-user@gmail.com"
+ };
+
+ var json = JsonSerializer.Serialize(input);
+
+ var output = JsonSerializer.Deserialize(json);
+ Assert.NotNull(output);
+ Assert.True(output.IsT2);
+
+ Assert.Equivalent(input.AsT2, output.AsT2);
+ }
+}
diff --git a/test/Core.Test/Billing/Payment/Queries/GetBillingAddressQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetBillingAddressQueryTests.cs
new file mode 100644
index 0000000000..048c143a0e
--- /dev/null
+++ b/test/Core.Test/Billing/Payment/Queries/GetBillingAddressQueryTests.cs
@@ -0,0 +1,204 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.Billing.Enums;
+using Bit.Core.Billing.Payment.Models;
+using Bit.Core.Billing.Payment.Queries;
+using Bit.Core.Billing.Services;
+using Bit.Core.Entities;
+using Bit.Core.Test.Billing.Extensions;
+using NSubstitute;
+using Stripe;
+using Xunit;
+
+namespace Bit.Core.Test.Billing.Payment.Queries;
+
+public class GetBillingAddressQueryTests
+{
+ private readonly ISubscriberService _subscriberService = Substitute.For();
+ private readonly GetBillingAddressQuery _query;
+
+ public GetBillingAddressQueryTests()
+ {
+ _query = new GetBillingAddressQuery(_subscriberService);
+ }
+
+ [Fact]
+ public async Task Run_ForUserWithNoAddress_ReturnsNull()
+ {
+ var user = new User();
+
+ var customer = new Customer();
+
+ _subscriberService.GetCustomer(user, Arg.Is(
+ options => options.Expand == null)).Returns(customer);
+
+ var billingAddress = await _query.Run(user);
+
+ Assert.Null(billingAddress);
+ }
+
+ [Fact]
+ public async Task Run_ForUserWithAddress_ReturnsBillingAddress()
+ {
+ var user = new User();
+
+ var address = GetAddress();
+
+ var customer = new Customer
+ {
+ Address = address
+ };
+
+ _subscriberService.GetCustomer(user, Arg.Is(
+ options => options.Expand == null)).Returns(customer);
+
+ var billingAddress = await _query.Run(user);
+
+ AssertEquality(address, billingAddress);
+ }
+
+ [Fact]
+ public async Task Run_ForPersonalOrganizationWithNoAddress_ReturnsNull()
+ {
+ var organization = new Organization
+ {
+ PlanType = PlanType.FamiliesAnnually
+ };
+
+ var customer = new Customer();
+
+ _subscriberService.GetCustomer(organization, Arg.Is(
+ options => options.Expand == null)).Returns(customer);
+
+ var billingAddress = await _query.Run(organization);
+
+ Assert.Null(billingAddress);
+ }
+
+ [Fact]
+ public async Task Run_ForPersonalOrganizationWithAddress_ReturnsBillingAddress()
+ {
+ var organization = new Organization
+ {
+ PlanType = PlanType.FamiliesAnnually
+ };
+
+ var address = GetAddress();
+
+ var customer = new Customer
+ {
+ Address = address
+ };
+
+ _subscriberService.GetCustomer(organization, Arg.Is(
+ options => options.Expand == null)).Returns(customer);
+
+ var billingAddress = await _query.Run(organization);
+
+ AssertEquality(customer.Address, billingAddress);
+ }
+
+ [Fact]
+ public async Task Run_ForBusinessOrganizationWithNoAddress_ReturnsNull()
+ {
+ var organization = new Organization
+ {
+ PlanType = PlanType.EnterpriseAnnually
+ };
+
+ var customer = new Customer();
+
+ _subscriberService.GetCustomer(organization, Arg.Is(
+ options => options.HasExpansions("tax_ids"))).Returns(customer);
+
+ var billingAddress = await _query.Run(organization);
+
+ Assert.Null(billingAddress);
+ }
+
+ [Fact]
+ public async Task Run_ForBusinessOrganizationWithAddressAndTaxId_ReturnsBillingAddressWithTaxId()
+ {
+ var organization = new Organization
+ {
+ PlanType = PlanType.EnterpriseAnnually
+ };
+
+ var address = GetAddress();
+
+ var taxId = GetTaxId();
+
+ var customer = new Customer
+ {
+ Address = address,
+ TaxIds = new StripeList
+ {
+ Data = [taxId]
+ }
+ };
+
+ _subscriberService.GetCustomer(organization, Arg.Is(
+ options => options.HasExpansions("tax_ids"))).Returns(customer);
+
+ var billingAddress = await _query.Run(organization);
+
+ AssertEquality(address, taxId, billingAddress);
+ }
+
+ [Fact]
+ public async Task Run_ForProviderWithAddressAndTaxId_ReturnsBillingAddressWithTaxId()
+ {
+ var provider = new Provider();
+
+ var address = GetAddress();
+
+ var taxId = GetTaxId();
+
+ var customer = new Customer
+ {
+ Address = address,
+ TaxIds = new StripeList
+ {
+ Data = [taxId]
+ }
+ };
+
+ _subscriberService.GetCustomer(provider, Arg.Is(
+ options => options.HasExpansions("tax_ids"))).Returns(customer);
+
+ var billingAddress = await _query.Run(provider);
+
+ AssertEquality(address, taxId, billingAddress);
+ }
+
+ private static void AssertEquality(Address address, BillingAddress? billingAddress)
+ {
+ Assert.NotNull(billingAddress);
+ Assert.Equal(address.Country, billingAddress.Country);
+ Assert.Equal(address.PostalCode, billingAddress.PostalCode);
+ Assert.Equal(address.Line1, billingAddress.Line1);
+ Assert.Equal(address.Line2, billingAddress.Line2);
+ Assert.Equal(address.City, billingAddress.City);
+ Assert.Equal(address.State, billingAddress.State);
+ }
+
+ private static void AssertEquality(Address address, TaxId taxId, BillingAddress? billingAddress)
+ {
+ AssertEquality(address, billingAddress);
+ Assert.NotNull(billingAddress!.TaxId);
+ Assert.Equal(taxId.Type, billingAddress.TaxId!.Code);
+ Assert.Equal(taxId.Value, billingAddress.TaxId!.Value);
+ }
+
+ private static Address GetAddress() => new()
+ {
+ Country = "US",
+ PostalCode = "12345",
+ Line1 = "123 Main St.",
+ Line2 = "Suite 100",
+ City = "New York",
+ State = "NY"
+ };
+
+ private static TaxId GetTaxId() => new() { Type = "us_ein", Value = "123456789" };
+}
diff --git a/test/Core.Test/Billing/Payment/Queries/GetCreditQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetCreditQueryTests.cs
new file mode 100644
index 0000000000..55f5e85009
--- /dev/null
+++ b/test/Core.Test/Billing/Payment/Queries/GetCreditQueryTests.cs
@@ -0,0 +1,41 @@
+using Bit.Core.Billing.Payment.Queries;
+using Bit.Core.Billing.Services;
+using Bit.Core.Entities;
+using NSubstitute;
+using NSubstitute.ReturnsExtensions;
+using Stripe;
+using Xunit;
+
+namespace Bit.Core.Test.Billing.Payment.Queries;
+
+public class GetCreditQueryTests
+{
+ private readonly ISubscriberService _subscriberService = Substitute.For();
+ private readonly GetCreditQuery _query;
+
+ public GetCreditQueryTests()
+ {
+ _query = new GetCreditQuery(_subscriberService);
+ }
+
+ [Fact]
+ public async Task Run_NoCustomer_ReturnsNull()
+ {
+ _subscriberService.GetCustomer(Arg.Any()).ReturnsNull();
+
+ var credit = await _query.Run(Substitute.For());
+
+ Assert.Null(credit);
+ }
+
+ [Fact]
+ public async Task Run_ReturnsCredit()
+ {
+ _subscriberService.GetCustomer(Arg.Any()).Returns(new Customer { Balance = -1000 });
+
+ var credit = await _query.Run(Substitute.For());
+
+ Assert.NotNull(credit);
+ Assert.Equal(10M, credit);
+ }
+}
diff --git a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs
new file mode 100644
index 0000000000..4d82b4b5c9
--- /dev/null
+++ b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs
@@ -0,0 +1,327 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.Billing.Caches;
+using Bit.Core.Billing.Constants;
+using Bit.Core.Billing.Payment.Queries;
+using Bit.Core.Billing.Services;
+using Bit.Core.Services;
+using Bit.Core.Test.Billing.Extensions;
+using Braintree;
+using Microsoft.Extensions.Logging;
+using NSubstitute;
+using Stripe;
+using Xunit;
+using Customer = Stripe.Customer;
+using PaymentMethod = Stripe.PaymentMethod;
+
+namespace Bit.Core.Test.Billing.Payment.Queries;
+
+using static StripeConstants;
+
+public class GetPaymentMethodQueryTests
+{
+ private readonly IBraintreeGateway _braintreeGateway = Substitute.For();
+ private readonly ISetupIntentCache _setupIntentCache = Substitute.For();
+ private readonly IStripeAdapter _stripeAdapter = Substitute.For();
+ private readonly ISubscriberService _subscriberService = Substitute.For