From 7f65a655d4dbfe578b8b5f982de34a6dc757562e Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Thu, 10 Jul 2025 08:32:25 -0500 Subject: [PATCH] [PM-21881] Manage payment details outside of checkout (#6032) * Add feature flag * Further establish billing command pattern and use in PreviewTaxAmountCommand * Add billing address models/commands/queries/tests * Update TypeReadingJsonConverter to account for new union types * Add payment method models/commands/queries/tests * Add credit models/commands/queries/tests * Add command/query registrations * Add new endpoints to support new command model and payment functionality * Run dotnet format * Add InjectUserAttribute for easier AccountBillilngVNextController handling * Add InjectOrganizationAttribute for easier OrganizationBillingVNextController handling * Add InjectProviderAttribute for easier ProviderBillingVNextController handling * Add XML documentation for billing command pipeline * Fix StripeConstants post-nullability * More nullability cleanup * Run dotnet format --- .../Attributes/InjectOrganizationAttribute.cs | 61 +++ .../Attributes/InjectProviderAttribute.cs | 80 ++++ .../Billing/Attributes/InjectUserAttribute.cs | 53 +++ .../Controllers/BaseBillingController.cs | 44 +- src/Api/Billing/Controllers/TaxController.cs | 5 +- .../VNext/AccountBillingVNextController.cs | 64 +++ .../OrganizationBillingVNextController.cs | 107 +++++ .../VNext/ProviderBillingVNextController.cs | 97 +++++ .../Requests/Payment/BillingAddressRequest.cs | 20 + .../Requests/Payment/BitPayCreditRequest.cs | 13 + .../Payment/CheckoutBillingAddressRequest.cs | 24 ++ .../Payment/MinimalBillingAddressRequest.cs | 16 + .../Payment/TokenizedPaymentMethodRequest.cs | 39 ++ .../Payment/VerifyBankAccountRequest.cs | 9 + .../ManageOrganizationBillingRequirement.cs | 18 + src/Api/Utilities/StringMatchesAttribute.cs | 18 + src/Core/Billing/Commands/BillingCommand.cs | 62 +++ .../Billing/Commands/BillingCommandResult.cs | 31 ++ src/Core/Billing/Constants/StripeConstants.cs | 12 +- .../Extensions/ServiceCollectionExtensions.cs | 2 + .../Extensions/SubscriberExtensions.cs | 16 +- .../Billing/Models/BillingCommandResult.cs | 36 -- .../Billing/Payment/Clients/BitPayClient.cs | 24 ++ .../CreateBitPayInvoiceForCreditCommand.cs | 59 +++ .../Commands/UpdateBillingAddressCommand.cs | 129 ++++++ .../Commands/UpdatePaymentMethodCommand.cs | 205 +++++++++ .../Commands/VerifyBankAccountCommand.cs | 63 +++ .../Billing/Payment/Models/BillingAddress.cs | 30 ++ .../Payment/Models/MaskedPaymentMethod.cs | 120 ++++++ .../Payment/Models/ProductUsageType.cs | 7 + .../Models/TokenizablePaymentMethodType.cs | 8 + .../Payment/Models/TokenizedPaymentMethod.cs | 8 + .../Payment/Queries/GetBillingAddressQuery.cs | 41 ++ .../Billing/Payment/Queries/GetCreditQuery.cs | 26 ++ .../Payment/Queries/GetPaymentMethodQuery.cs | 96 +++++ src/Core/Billing/Payment/Registrations.cs | 24 ++ .../Pricing/JSON/TypeReadingJsonConverter.cs | 12 +- .../Tax/Commands/PreviewTaxAmountCommand.cs | 172 ++++---- src/Core/Constants.cs | 1 + .../InjectOrganizationAttributeTests.cs | 132 ++++++ .../InjectProviderAttributeTests.cs | 190 +++++++++ .../Attributes/InjectUserAttributesTests.cs | 129 ++++++ .../Billing/Extensions/StripeExtensions.cs | 18 + ...reateBitPayInvoiceForCreditCommandTests.cs | 94 +++++ .../UpdateBillingAddressCommandTests.cs | 349 +++++++++++++++ .../UpdatePaymentMethodCommandTests.cs | 399 ++++++++++++++++++ .../Commands/VerifyBankAccountCommandTests.cs | 81 ++++ .../Models/MaskedPaymentMethodTests.cs | 63 +++ .../Queries/GetBillingAddressQueryTests.cs | 204 +++++++++ .../Payment/Queries/GetCreditQueryTests.cs | 41 ++ .../Queries/GetPaymentMethodQueryTests.cs | 327 ++++++++++++++ .../Commands/PreviewTaxAmountCommandTests.cs | 72 +--- 52 files changed, 3736 insertions(+), 215 deletions(-) create mode 100644 src/Api/Billing/Attributes/InjectOrganizationAttribute.cs create mode 100644 src/Api/Billing/Attributes/InjectProviderAttribute.cs create mode 100644 src/Api/Billing/Attributes/InjectUserAttribute.cs create mode 100644 src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs create mode 100644 src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs create mode 100644 src/Api/Billing/Controllers/VNext/ProviderBillingVNextController.cs create mode 100644 src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs create mode 100644 src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs create mode 100644 src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs create mode 100644 src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs create mode 100644 src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs create mode 100644 src/Api/Billing/Models/Requests/Payment/VerifyBankAccountRequest.cs create mode 100644 src/Api/Billing/Models/Requirements/ManageOrganizationBillingRequirement.cs create mode 100644 src/Api/Utilities/StringMatchesAttribute.cs create mode 100644 src/Core/Billing/Commands/BillingCommand.cs create mode 100644 src/Core/Billing/Commands/BillingCommandResult.cs delete mode 100644 src/Core/Billing/Models/BillingCommandResult.cs create mode 100644 src/Core/Billing/Payment/Clients/BitPayClient.cs create mode 100644 src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs create mode 100644 src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs create mode 100644 src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs create mode 100644 src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs create mode 100644 src/Core/Billing/Payment/Models/BillingAddress.cs create mode 100644 src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs create mode 100644 src/Core/Billing/Payment/Models/ProductUsageType.cs create mode 100644 src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs create mode 100644 src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs create mode 100644 src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs create mode 100644 src/Core/Billing/Payment/Queries/GetCreditQuery.cs create mode 100644 src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs create mode 100644 src/Core/Billing/Payment/Registrations.cs create mode 100644 test/Api.Test/Billing/Attributes/InjectOrganizationAttributeTests.cs create mode 100644 test/Api.Test/Billing/Attributes/InjectProviderAttributeTests.cs create mode 100644 test/Api.Test/Billing/Attributes/InjectUserAttributesTests.cs create mode 100644 test/Core.Test/Billing/Extensions/StripeExtensions.cs create mode 100644 test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs create mode 100644 test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs create mode 100644 test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs create mode 100644 test/Core.Test/Billing/Payment/Commands/VerifyBankAccountCommandTests.cs create mode 100644 test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs create mode 100644 test/Core.Test/Billing/Payment/Queries/GetBillingAddressQueryTests.cs create mode 100644 test/Core.Test/Billing/Payment/Queries/GetCreditQueryTests.cs create mode 100644 test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs diff --git a/src/Api/Billing/Attributes/InjectOrganizationAttribute.cs b/src/Api/Billing/Attributes/InjectOrganizationAttribute.cs new file mode 100644 index 0000000000..f4c2a8c637 --- /dev/null +++ b/src/Api/Billing/Attributes/InjectOrganizationAttribute.cs @@ -0,0 +1,61 @@ +#nullable enable +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Models.Api; +using Bit.Core.Repositories; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; + +namespace Bit.Api.Billing.Attributes; + +/// +/// An action filter that facilitates the injection of a parameter into the executing action method arguments. +/// +/// +/// This attribute retrieves the organization associated with the 'organizationId' included in the executing context's route data. If the organization cannot be found, +/// the request is terminated with a not found response. +/// The injected +/// parameter must be marked with a [BindNever] attribute to short-circuit the model-binding system. +/// +/// +/// EndpointAsync([BindNever] Organization organization) +/// ]]> +/// +/// +public class InjectOrganizationAttribute : ActionFilterAttribute +{ + public override async Task OnActionExecutionAsync( + ActionExecutingContext context, + ActionExecutionDelegate next) + { + if (!context.RouteData.Values.TryGetValue("organizationId", out var routeValue) || + !Guid.TryParse(routeValue?.ToString(), out var organizationId)) + { + context.Result = new BadRequestObjectResult(new ErrorResponseModel("Route parameter 'organizationId' is missing or invalid.")); + return; + } + + var organizationRepository = context.HttpContext.RequestServices + .GetRequiredService(); + + var organization = await organizationRepository.GetByIdAsync(organizationId); + + if (organization == null) + { + context.Result = new NotFoundObjectResult(new ErrorResponseModel("Organization not found.")); + return; + } + + var organizationParameter = context.ActionDescriptor.Parameters + .FirstOrDefault(p => p.ParameterType == typeof(Organization)); + + if (organizationParameter != null) + { + context.ActionArguments[organizationParameter.Name] = organization; + } + + await next(); + } +} diff --git a/src/Api/Billing/Attributes/InjectProviderAttribute.cs b/src/Api/Billing/Attributes/InjectProviderAttribute.cs new file mode 100644 index 0000000000..e65dda37c3 --- /dev/null +++ b/src/Api/Billing/Attributes/InjectProviderAttribute.cs @@ -0,0 +1,80 @@ +#nullable enable +using Bit.Api.Models.Public.Response; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Context; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; + +namespace Bit.Api.Billing.Attributes; + +/// +/// An action filter that facilitates the injection of a parameter into the executing action method arguments after performing an authorization check. +/// +/// +/// This attribute retrieves the provider associated with the 'providerId' included in the executing context's route data. If the provider cannot be found, +/// the request is terminated with a not-found response. It then checks the authorization level for the provider using the provided . +/// If this check fails, the request is terminated with an unauthorized response. +/// The injected +/// parameter must be marked with a [BindNever] attribute to short-circuit the model-binding system. +/// +/// +/// EndpointAsync([BindNever] Provider provider) +/// ]]> +/// +/// The desired access level for the authorization check. +/// +public class InjectProviderAttribute(ProviderUserType providerUserType) : ActionFilterAttribute +{ + public override async Task OnActionExecutionAsync( + ActionExecutingContext context, + ActionExecutionDelegate next) + { + if (!context.RouteData.Values.TryGetValue("providerId", out var routeValue) || + !Guid.TryParse(routeValue?.ToString(), out var providerId)) + { + context.Result = new BadRequestObjectResult(new ErrorResponseModel("Route parameter 'providerId' is missing or invalid.")); + return; + } + + var providerRepository = context.HttpContext.RequestServices + .GetRequiredService(); + + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + context.Result = new NotFoundObjectResult(new ErrorResponseModel("Provider not found.")); + return; + } + + var currentContext = context.HttpContext.RequestServices.GetRequiredService(); + + var unauthorized = providerUserType switch + { + ProviderUserType.ProviderAdmin => !currentContext.ProviderProviderAdmin(providerId), + ProviderUserType.ServiceUser => !currentContext.ProviderUser(providerId), + _ => false + }; + + if (unauthorized) + { + context.Result = new UnauthorizedObjectResult(new ErrorResponseModel("Unauthorized.")); + return; + } + + var providerParameter = context.ActionDescriptor.Parameters + .FirstOrDefault(p => p.ParameterType == typeof(Provider)); + + if (providerParameter != null) + { + context.ActionArguments[providerParameter.Name] = provider; + } + + await next(); + } +} diff --git a/src/Api/Billing/Attributes/InjectUserAttribute.cs b/src/Api/Billing/Attributes/InjectUserAttribute.cs new file mode 100644 index 0000000000..0b614bdc44 --- /dev/null +++ b/src/Api/Billing/Attributes/InjectUserAttribute.cs @@ -0,0 +1,53 @@ +#nullable enable +using Bit.Core.Entities; +using Bit.Core.Models.Api; +using Bit.Core.Services; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; + +namespace Bit.Api.Billing.Attributes; + +/// +/// An action filter that facilitates the injection of a parameter into the executing action method arguments. +/// +/// +/// This attribute retrieves the authorized user associated with the current HTTP context using the service. +/// If the user is unauthorized or cannot be found, the request is terminated with an unauthorized response. +/// The injected +/// parameter must be marked with a [BindNever] attribute to short-circuit the model-binding system. +/// +/// +/// EndpointAsync([BindNever] User user) +/// ]]> +/// +/// +public class InjectUserAttribute : ActionFilterAttribute +{ + public override async Task OnActionExecutionAsync( + ActionExecutingContext context, + ActionExecutionDelegate next) + { + var userService = context.HttpContext.RequestServices.GetRequiredService(); + + var user = await userService.GetUserByPrincipalAsync(context.HttpContext.User); + + if (user == null) + { + context.Result = new UnauthorizedObjectResult(new ErrorResponseModel("Unauthorized.")); + return; + } + + var userParameter = + context.ActionDescriptor.Parameters.FirstOrDefault(parameter => parameter.ParameterType == typeof(User)); + + if (userParameter != null) + { + context.ActionArguments[userParameter.Name] = user; + } + + await next(); + } +} diff --git a/src/Api/Billing/Controllers/BaseBillingController.cs b/src/Api/Billing/Controllers/BaseBillingController.cs index 5f7005fdfc..057c8309fb 100644 --- a/src/Api/Billing/Controllers/BaseBillingController.cs +++ b/src/Api/Billing/Controllers/BaseBillingController.cs @@ -1,4 +1,6 @@ -using Bit.Core.Models.Api; +#nullable enable +using Bit.Core.Billing.Commands; +using Bit.Core.Models.Api; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Mvc; @@ -6,20 +8,50 @@ namespace Bit.Api.Billing.Controllers; public abstract class BaseBillingController : Controller { + /// + /// Processes the result of a billing command and converts it to an appropriate HTTP result response. + /// + /// + /// Result to response mappings: + /// + /// : 200 OK + /// : 400 BAD_REQUEST + /// : 409 CONFLICT + /// : 500 INTERNAL_SERVER_ERROR + /// + /// + /// The type of the successful result. + /// The result of executing the billing command. + /// An HTTP result response representing the outcome of the command execution. + protected static IResult Handle(BillingCommandResult result) => + result.Match( + TypedResults.Ok, + badRequest => Error.BadRequest(badRequest.Response), + conflict => Error.Conflict(conflict.Response), + unhandled => Error.ServerError(unhandled.Response, unhandled.Exception)); + protected static class Error { - public static BadRequest BadRequest(Dictionary> errors) => - TypedResults.BadRequest(new ErrorResponseModel(errors)); - public static BadRequest BadRequest(string message) => TypedResults.BadRequest(new ErrorResponseModel(message)); + public static JsonHttpResult Conflict(string message) => + TypedResults.Json( + new ErrorResponseModel(message), + statusCode: StatusCodes.Status409Conflict); + public static NotFound NotFound() => TypedResults.NotFound(new ErrorResponseModel("Resource not found.")); - public static JsonHttpResult ServerError(string message = "Something went wrong with your request. Please contact support.") => + public static JsonHttpResult ServerError( + string message = "Something went wrong with your request. Please contact support for assistance.", + Exception? exception = null) => TypedResults.Json( - new ErrorResponseModel(message), + exception == null ? new ErrorResponseModel(message) : new ErrorResponseModel(message) + { + ExceptionMessage = exception.Message, + ExceptionStackTrace = exception.StackTrace + }, statusCode: StatusCodes.Status500InternalServerError); public static JsonHttpResult Unauthorized(string message = "Unauthorized.") => diff --git a/src/Api/Billing/Controllers/TaxController.cs b/src/Api/Billing/Controllers/TaxController.cs index 7b8b9d960f..d2c1c36726 100644 --- a/src/Api/Billing/Controllers/TaxController.cs +++ b/src/Api/Billing/Controllers/TaxController.cs @@ -28,9 +28,6 @@ public class TaxController( var result = await previewTaxAmountCommand.Run(parameters); - return result.Match( - taxAmount => TypedResults.Ok(new { TaxAmount = taxAmount }), - badRequest => Error.BadRequest(badRequest.TranslationKey), - unhandled => Error.ServerError(unhandled.TranslationKey)); + return Handle(result); } } diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs new file mode 100644 index 0000000000..e3b702e36d --- /dev/null +++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs @@ -0,0 +1,64 @@ +#nullable enable +using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Payment.Commands; +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Entities; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ModelBinding; + +namespace Bit.Api.Billing.Controllers.VNext; + +[Authorize("Application")] +[Route("account/billing/vnext")] +[SelfHosted(NotSelfHostedOnly = true)] +public class AccountBillingVNextController( + ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand, + IGetCreditQuery getCreditQuery, + IGetPaymentMethodQuery getPaymentMethodQuery, + IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingController +{ + [HttpGet("credit")] + [InjectUser] + public async Task GetCreditAsync( + [BindNever] User user) + { + var credit = await getCreditQuery.Run(user); + return TypedResults.Ok(credit); + } + + [HttpPost("credit/bitpay")] + [InjectUser] + public async Task AddCreditViaBitPayAsync( + [BindNever] User user, + [FromBody] BitPayCreditRequest request) + { + var result = await createBitPayInvoiceForCreditCommand.Run( + user, + request.Amount, + request.RedirectUrl); + return Handle(result); + } + + [HttpGet("payment-method")] + [InjectUser] + public async Task GetPaymentMethodAsync( + [BindNever] User user) + { + var paymentMethod = await getPaymentMethodQuery.Run(user); + return TypedResults.Ok(paymentMethod); + } + + [HttpPut("payment-method")] + [InjectUser] + public async Task UpdatePaymentMethodAsync( + [BindNever] User user, + [FromBody] TokenizedPaymentMethodRequest request) + { + var (paymentMethod, billingAddress) = request.ToDomain(); + var result = await updatePaymentMethodCommand.Run(user, paymentMethod, billingAddress); + return Handle(result); + } +} diff --git a/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs new file mode 100644 index 0000000000..429f2065f6 --- /dev/null +++ b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs @@ -0,0 +1,107 @@ +#nullable enable +using Bit.Api.AdminConsole.Authorization; +using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Api.Billing.Models.Requirements; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Payment.Commands; +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ModelBinding; +// ReSharper disable RouteTemplates.MethodMissingRouteParameters + +namespace Bit.Api.Billing.Controllers.VNext; + +[Authorize("Application")] +[Route("organizations/{organizationId:guid}/billing/vnext")] +[SelfHosted(NotSelfHostedOnly = true)] +public class OrganizationBillingVNextController( + ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand, + IGetBillingAddressQuery getBillingAddressQuery, + IGetCreditQuery getCreditQuery, + IGetPaymentMethodQuery getPaymentMethodQuery, + IUpdateBillingAddressCommand updateBillingAddressCommand, + IUpdatePaymentMethodCommand updatePaymentMethodCommand, + IVerifyBankAccountCommand verifyBankAccountCommand) : BaseBillingController +{ + [Authorize] + [HttpGet("address")] + [InjectOrganization] + public async Task GetBillingAddressAsync( + [BindNever] Organization organization) + { + var billingAddress = await getBillingAddressQuery.Run(organization); + return TypedResults.Ok(billingAddress); + } + + [Authorize] + [HttpPut("address")] + [InjectOrganization] + public async Task UpdateBillingAddressAsync( + [BindNever] Organization organization, + [FromBody] BillingAddressRequest request) + { + var billingAddress = request.ToDomain(); + var result = await updateBillingAddressCommand.Run(organization, billingAddress); + return Handle(result); + } + + [Authorize] + [HttpGet("credit")] + [InjectOrganization] + public async Task GetCreditAsync( + [BindNever] Organization organization) + { + var credit = await getCreditQuery.Run(organization); + return TypedResults.Ok(credit); + } + + [Authorize] + [HttpPost("credit/bitpay")] + [InjectOrganization] + public async Task AddCreditViaBitPayAsync( + [BindNever] Organization organization, + [FromBody] BitPayCreditRequest request) + { + var result = await createBitPayInvoiceForCreditCommand.Run( + organization, + request.Amount, + request.RedirectUrl); + return Handle(result); + } + + [Authorize] + [HttpGet("payment-method")] + [InjectOrganization] + public async Task GetPaymentMethodAsync( + [BindNever] Organization organization) + { + var paymentMethod = await getPaymentMethodQuery.Run(organization); + return TypedResults.Ok(paymentMethod); + } + + [Authorize] + [HttpPut("payment-method")] + [InjectOrganization] + public async Task UpdatePaymentMethodAsync( + [BindNever] Organization organization, + [FromBody] TokenizedPaymentMethodRequest request) + { + var (paymentMethod, billingAddress) = request.ToDomain(); + var result = await updatePaymentMethodCommand.Run(organization, paymentMethod, billingAddress); + return Handle(result); + } + + [Authorize] + [HttpPost("payment-method/verify-bank-account")] + [InjectOrganization] + public async Task VerifyBankAccountAsync( + [BindNever] Organization organization, + [FromBody] VerifyBankAccountRequest request) + { + var result = await verifyBankAccountCommand.Run(organization, request.DescriptorCode); + return Handle(result); + } +} diff --git a/src/Api/Billing/Controllers/VNext/ProviderBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/ProviderBillingVNextController.cs new file mode 100644 index 0000000000..be7963236f --- /dev/null +++ b/src/Api/Billing/Controllers/VNext/ProviderBillingVNextController.cs @@ -0,0 +1,97 @@ +#nullable enable +using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Payment.Commands; +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ModelBinding; +// ReSharper disable RouteTemplates.MethodMissingRouteParameters + +namespace Bit.Api.Billing.Controllers.VNext; + +[Route("providers/{providerId:guid}/billing/vnext")] +[SelfHosted(NotSelfHostedOnly = true)] +public class ProviderBillingVNextController( + ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand, + IGetBillingAddressQuery getBillingAddressQuery, + IGetCreditQuery getCreditQuery, + IGetPaymentMethodQuery getPaymentMethodQuery, + IUpdateBillingAddressCommand updateBillingAddressCommand, + IUpdatePaymentMethodCommand updatePaymentMethodCommand, + IVerifyBankAccountCommand verifyBankAccountCommand) : BaseBillingController +{ + [HttpGet("address")] + [InjectProvider(ProviderUserType.ProviderAdmin)] + public async Task GetBillingAddressAsync( + [BindNever] Provider provider) + { + var billingAddress = await getBillingAddressQuery.Run(provider); + return TypedResults.Ok(billingAddress); + } + + [HttpPut("address")] + [InjectProvider(ProviderUserType.ProviderAdmin)] + public async Task UpdateBillingAddressAsync( + [BindNever] Provider provider, + [FromBody] BillingAddressRequest request) + { + var billingAddress = request.ToDomain(); + var result = await updateBillingAddressCommand.Run(provider, billingAddress); + return Handle(result); + } + + [HttpGet("credit")] + [InjectProvider(ProviderUserType.ProviderAdmin)] + public async Task GetCreditAsync( + [BindNever] Provider provider) + { + var credit = await getCreditQuery.Run(provider); + return TypedResults.Ok(credit); + } + + [HttpPost("credit/bitpay")] + [InjectProvider(ProviderUserType.ProviderAdmin)] + public async Task AddCreditViaBitPayAsync( + [BindNever] Provider provider, + [FromBody] BitPayCreditRequest request) + { + var result = await createBitPayInvoiceForCreditCommand.Run( + provider, + request.Amount, + request.RedirectUrl); + return Handle(result); + } + + [HttpGet("payment-method")] + [InjectProvider(ProviderUserType.ProviderAdmin)] + public async Task GetPaymentMethodAsync( + [BindNever] Provider provider) + { + var paymentMethod = await getPaymentMethodQuery.Run(provider); + return TypedResults.Ok(paymentMethod); + } + + [HttpPut("payment-method")] + [InjectProvider(ProviderUserType.ProviderAdmin)] + public async Task UpdatePaymentMethodAsync( + [BindNever] Provider provider, + [FromBody] TokenizedPaymentMethodRequest request) + { + var (paymentMethod, billingAddress) = request.ToDomain(); + var result = await updatePaymentMethodCommand.Run(provider, paymentMethod, billingAddress); + return Handle(result); + } + + [HttpPost("payment-method/verify-bank-account")] + [InjectProvider(ProviderUserType.ProviderAdmin)] + public async Task VerifyBankAccountAsync( + [BindNever] Provider provider, + [FromBody] VerifyBankAccountRequest request) + { + var result = await verifyBankAccountCommand.Run(provider, request.DescriptorCode); + return Handle(result); + } +} diff --git a/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs new file mode 100644 index 0000000000..5c3c47f585 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs @@ -0,0 +1,20 @@ +#nullable enable +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Payment; + +public record BillingAddressRequest : CheckoutBillingAddressRequest +{ + public string? Line1 { get; set; } + public string? Line2 { get; set; } + public string? City { get; set; } + public string? State { get; set; } + + public override BillingAddress ToDomain() => base.ToDomain() with + { + Line1 = Line1, + Line2 = Line2, + City = City, + State = State, + }; +} diff --git a/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs b/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs new file mode 100644 index 0000000000..bb6e7498d7 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs @@ -0,0 +1,13 @@ +#nullable enable +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.Billing.Models.Requests.Payment; + +public record BitPayCreditRequest +{ + [Required] + public required decimal Amount { get; set; } + + [Required] + public required string RedirectUrl { get; set; } = null!; +} diff --git a/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs new file mode 100644 index 0000000000..54116e897d --- /dev/null +++ b/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs @@ -0,0 +1,24 @@ +#nullable enable +using System.ComponentModel.DataAnnotations; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Payment; + +public record CheckoutBillingAddressRequest : MinimalBillingAddressRequest +{ + public TaxIdRequest? TaxId { get; set; } + + public override BillingAddress ToDomain() => base.ToDomain() with + { + TaxId = TaxId != null ? new TaxID(TaxId.Code, TaxId.Value) : null + }; + + public class TaxIdRequest + { + [Required] + public string Code { get; set; } = null!; + + [Required] + public string Value { get; set; } = null!; + } +} diff --git a/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs new file mode 100644 index 0000000000..b4d28017d5 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs @@ -0,0 +1,16 @@ +#nullable enable +using System.ComponentModel.DataAnnotations; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Payment; + +public record MinimalBillingAddressRequest +{ + [Required] + [StringLength(2, MinimumLength = 2, ErrorMessage = "Country code must be 2 characters long.")] + public required string Country { get; set; } = null!; + [Required] + public required string PostalCode { get; set; } = null!; + + public virtual BillingAddress ToDomain() => new() { Country = Country, PostalCode = PostalCode, }; +} diff --git a/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs new file mode 100644 index 0000000000..663e4e7cd2 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs @@ -0,0 +1,39 @@ +#nullable enable +using System.ComponentModel.DataAnnotations; +using Bit.Api.Utilities; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Payment; + +public class TokenizedPaymentMethodRequest +{ + [Required] + [StringMatches("bankAccount", "card", "payPal", + ErrorMessage = "Payment method type must be one of: bankAccount, card, payPal")] + public required string Type { get; set; } + + [Required] + public required string Token { get; set; } + + public MinimalBillingAddressRequest? BillingAddress { get; set; } + + public (TokenizedPaymentMethod, BillingAddress?) ToDomain() + { + var paymentMethod = new TokenizedPaymentMethod + { + Type = Type switch + { + "bankAccount" => TokenizablePaymentMethodType.BankAccount, + "card" => TokenizablePaymentMethodType.Card, + "payPal" => TokenizablePaymentMethodType.PayPal, + _ => throw new InvalidOperationException( + $"Invalid value for {nameof(TokenizedPaymentMethod)}.{nameof(TokenizedPaymentMethod.Type)}") + }, + Token = Token + }; + + var billingAddress = BillingAddress?.ToDomain(); + + return (paymentMethod, billingAddress); + } +} diff --git a/src/Api/Billing/Models/Requests/Payment/VerifyBankAccountRequest.cs b/src/Api/Billing/Models/Requests/Payment/VerifyBankAccountRequest.cs new file mode 100644 index 0000000000..2b5d6a0cb1 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Payment/VerifyBankAccountRequest.cs @@ -0,0 +1,9 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.Billing.Models.Requests.Payment; + +public class VerifyBankAccountRequest +{ + [Required] + public required string DescriptorCode { get; set; } +} diff --git a/src/Api/Billing/Models/Requirements/ManageOrganizationBillingRequirement.cs b/src/Api/Billing/Models/Requirements/ManageOrganizationBillingRequirement.cs new file mode 100644 index 0000000000..4efdf0812a --- /dev/null +++ b/src/Api/Billing/Models/Requirements/ManageOrganizationBillingRequirement.cs @@ -0,0 +1,18 @@ +#nullable enable +using Bit.Api.AdminConsole.Authorization; +using Bit.Core.Context; +using Bit.Core.Enums; + +namespace Bit.Api.Billing.Models.Requirements; + +public class ManageOrganizationBillingRequirement : IOrganizationRequirement +{ + public async Task AuthorizeAsync( + CurrentContextOrganization? organizationClaims, + Func> isProviderUserForOrg) + => organizationClaims switch + { + { Type: OrganizationUserType.Owner } => true, + _ => await isProviderUserForOrg() + }; +} diff --git a/src/Api/Utilities/StringMatchesAttribute.cs b/src/Api/Utilities/StringMatchesAttribute.cs new file mode 100644 index 0000000000..28485aed40 --- /dev/null +++ b/src/Api/Utilities/StringMatchesAttribute.cs @@ -0,0 +1,18 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.Utilities; + +public class StringMatchesAttribute(params string[]? accepted) : ValidationAttribute +{ + public override bool IsValid(object? value) + { + if (value is not string str || + accepted == null || + accepted.Length == 0) + { + return false; + } + + return accepted.Contains(str); + } +} diff --git a/src/Core/Billing/Commands/BillingCommand.cs b/src/Core/Billing/Commands/BillingCommand.cs new file mode 100644 index 0000000000..e6c6375b62 --- /dev/null +++ b/src/Core/Billing/Commands/BillingCommand.cs @@ -0,0 +1,62 @@ +using Bit.Core.Billing.Constants; +using Microsoft.Extensions.Logging; +using Stripe; + +namespace Bit.Core.Billing.Commands; + +using static StripeConstants; + +public abstract class BillingCommand( + ILogger logger) +{ + protected string CommandName => GetType().Name; + + /// + /// Executes the provided function within a predefined execution context, handling any exceptions that occur during the process. + /// + /// The type of the successful result expected from the provided function. + /// A function that performs an operation and returns a . + /// A task that represents the operation. The result provides a which may indicate success or an error outcome. + protected async Task> HandleAsync( + Func>> function) + { + try + { + return await function(); + } + catch (StripeException stripeException) when (ErrorCodes.Get().Contains(stripeException.StripeError.Code)) + { + return stripeException.StripeError.Code switch + { + ErrorCodes.CustomerTaxLocationInvalid => + new BadRequest("Your location wasn't recognized. Please ensure your country and postal code are valid and try again."), + + ErrorCodes.PaymentMethodMicroDepositVerificationAttemptsExceeded => + new BadRequest("You have exceeded the number of allowed verification attempts. Please contact support for assistance."), + + ErrorCodes.PaymentMethodMicroDepositVerificationDescriptorCodeMismatch => + new BadRequest("The verification code you provided does not match the one sent to your bank account. Please try again."), + + ErrorCodes.PaymentMethodMicroDepositVerificationTimeout => + new BadRequest("Your bank account was not verified within the required time period. Please contact support for assistance."), + + ErrorCodes.TaxIdInvalid => + new BadRequest("The tax ID number you provided was invalid. Please try again or contact support for assistance."), + + _ => new Unhandled(stripeException) + }; + } + catch (StripeException stripeException) + { + logger.LogError(stripeException, + "{Command}: An error occurred while communicating with Stripe | Code = {Code}", CommandName, + stripeException.StripeError.Code); + return new Unhandled(stripeException); + } + catch (Exception exception) + { + logger.LogError(exception, "{Command}: An unknown error occurred during execution", CommandName); + return new Unhandled(exception); + } + } +} diff --git a/src/Core/Billing/Commands/BillingCommandResult.cs b/src/Core/Billing/Commands/BillingCommandResult.cs new file mode 100644 index 0000000000..b69ad4bf12 --- /dev/null +++ b/src/Core/Billing/Commands/BillingCommandResult.cs @@ -0,0 +1,31 @@ +#nullable enable +using OneOf; + +namespace Bit.Core.Billing.Commands; + +public record BadRequest(string Response); +public record Conflict(string Response); +public record Unhandled(Exception? Exception = null, string Response = "Something went wrong with your request. Please contact support for assistance."); + +/// +/// A union type representing the result of a billing command. +/// +/// Choices include: +/// +/// : Success +/// : Invalid input +/// : A known, but unresolvable issue +/// : An unknown issue +/// +/// +/// +/// The successful result type of the operation. +public class BillingCommandResult : OneOfBase +{ + private BillingCommandResult(OneOf input) : base(input) { } + + public static implicit operator BillingCommandResult(T output) => new(output); + public static implicit operator BillingCommandResult(BadRequest badRequest) => new(badRequest); + public static implicit operator BillingCommandResult(Conflict conflict) => new(conflict); + public static implicit operator BillingCommandResult(Unhandled unhandled) => new(unhandled); +} diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index 0cffad72d3..3aaa519d66 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -1,4 +1,6 @@ -namespace Bit.Core.Billing.Constants; +using System.Reflection; + +namespace Bit.Core.Billing.Constants; public static class StripeConstants { @@ -36,6 +38,13 @@ public static class StripeConstants public const string PaymentMethodMicroDepositVerificationDescriptorCodeMismatch = "payment_method_microdeposit_verification_descriptor_code_mismatch"; public const string PaymentMethodMicroDepositVerificationTimeout = "payment_method_microdeposit_verification_timeout"; public const string TaxIdInvalid = "tax_id_invalid"; + + public static string[] Get() => + typeof(ErrorCodes) + .GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy) + .Where(fi => fi is { IsLiteral: true, IsInitOnly: false } && fi.FieldType == typeof(string)) + .Select(fi => (string)fi.GetValue(null)!) + .ToArray(); } public static class InvoiceStatus @@ -51,6 +60,7 @@ public static class StripeConstants public const string InvoiceApproved = "invoice_approved"; public const string OrganizationId = "organizationId"; public const string ProviderId = "providerId"; + public const string RetiredBraintreeCustomerId = "btCustomerId_old"; public const string UserId = "userId"; } diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 5c7a42e9b8..5f1a0668bd 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -1,6 +1,7 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches.Implementations; using Bit.Core.Billing.Licenses.Extensions; +using Bit.Core.Billing.Payment; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; @@ -27,5 +28,6 @@ public static class ServiceCollectionExtensions services.AddLicenseServices(); services.AddPricingClient(); services.AddTransient(); + services.AddPaymentOperations(); } } diff --git a/src/Core/Billing/Extensions/SubscriberExtensions.cs b/src/Core/Billing/Extensions/SubscriberExtensions.cs index e322ed7317..fc804de224 100644 --- a/src/Core/Billing/Extensions/SubscriberExtensions.cs +++ b/src/Core/Billing/Extensions/SubscriberExtensions.cs @@ -1,4 +1,8 @@ -using Bit.Core.Entities; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Entities; namespace Bit.Core.Billing.Extensions; @@ -23,4 +27,14 @@ public static class SubscriberExtensions ? subscriberName : subscriberName[..30]; } + + public static ProductUsageType GetProductUsageType(this ISubscriber subscriber) + => subscriber switch + { + User => ProductUsageType.Personal, + Organization organization when organization.PlanType.GetProductTier() is ProductTierType.Free or ProductTierType.Families => ProductUsageType.Personal, + Organization => ProductUsageType.Business, + Provider => ProductUsageType.Business, + _ => throw new ArgumentOutOfRangeException(nameof(subscriber)) + }; } diff --git a/src/Core/Billing/Models/BillingCommandResult.cs b/src/Core/Billing/Models/BillingCommandResult.cs deleted file mode 100644 index 1b8eefe8df..0000000000 --- a/src/Core/Billing/Models/BillingCommandResult.cs +++ /dev/null @@ -1,36 +0,0 @@ -using OneOf; - -namespace Bit.Core.Billing.Models; - -public record BadRequest(string TranslationKey) -{ - public static BadRequest TaxIdNumberInvalid => new(BillingErrorTranslationKeys.TaxIdInvalid); - public static BadRequest TaxLocationInvalid => new(BillingErrorTranslationKeys.CustomerTaxLocationInvalid); - public static BadRequest UnknownTaxIdType => new(BillingErrorTranslationKeys.UnknownTaxIdType); -} - -public record Unhandled(string TranslationKey = BillingErrorTranslationKeys.UnhandledError); - -public class BillingCommandResult : OneOfBase -{ - private BillingCommandResult(OneOf input) : base(input) { } - - public static implicit operator BillingCommandResult(T output) => new(output); - public static implicit operator BillingCommandResult(BadRequest badRequest) => new(badRequest); - public static implicit operator BillingCommandResult(Unhandled unhandled) => new(unhandled); -} - -public static class BillingErrorTranslationKeys -{ - // "The tax ID number you provided was invalid. Please try again or contact support." - public const string TaxIdInvalid = "taxIdInvalid"; - - // "Your location wasn't recognized. Please ensure your country and postal code are valid and try again." - public const string CustomerTaxLocationInvalid = "customerTaxLocationInvalid"; - - // "Something went wrong with your request. Please contact support." - public const string UnhandledError = "unhandledBillingError"; - - // "We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support." - public const string UnknownTaxIdType = "unknownTaxIdType"; -} diff --git a/src/Core/Billing/Payment/Clients/BitPayClient.cs b/src/Core/Billing/Payment/Clients/BitPayClient.cs new file mode 100644 index 0000000000..2cb8fb66ef --- /dev/null +++ b/src/Core/Billing/Payment/Clients/BitPayClient.cs @@ -0,0 +1,24 @@ +using Bit.Core.Settings; +using BitPayLight; +using BitPayLight.Models.Invoice; + +namespace Bit.Core.Billing.Payment.Clients; + +public interface IBitPayClient +{ + Task GetInvoice(string invoiceId); + Task CreateInvoice(Invoice invoice); +} + +public class BitPayClient( + GlobalSettings globalSettings) : IBitPayClient +{ + private readonly BitPay _bitPay = new( + globalSettings.BitPay.Token, globalSettings.BitPay.Production ? Env.Prod : Env.Test); + + public Task GetInvoice(string invoiceId) + => _bitPay.GetInvoice(invoiceId); + + public Task CreateInvoice(Invoice invoice) + => _bitPay.CreateInvoice(invoice); +} diff --git a/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs new file mode 100644 index 0000000000..f61fa9d0f9 --- /dev/null +++ b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs @@ -0,0 +1,59 @@ +#nullable enable +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Payment.Clients; +using Bit.Core.Entities; +using Bit.Core.Settings; +using BitPayLight.Models.Invoice; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.Billing.Payment.Commands; + +public interface ICreateBitPayInvoiceForCreditCommand +{ + Task> Run( + ISubscriber subscriber, + decimal amount, + string redirectUrl); +} + +public class CreateBitPayInvoiceForCreditCommand( + IBitPayClient bitPayClient, + GlobalSettings globalSettings, + ILogger logger) : BillingCommand(logger), ICreateBitPayInvoiceForCreditCommand +{ + public Task> Run( + ISubscriber subscriber, + decimal amount, + string redirectUrl) => HandleAsync(async () => + { + var (name, email, posData) = GetSubscriberInformation(subscriber); + + var invoice = new Invoice + { + Buyer = new Buyer { Email = email, Name = name }, + Currency = "USD", + ExtendedNotifications = true, + FullNotifications = true, + ItemDesc = "Bitwarden", + NotificationUrl = globalSettings.BitPay.NotificationUrl, + PosData = posData, + Price = Convert.ToDouble(amount), + RedirectUrl = redirectUrl + }; + + var created = await bitPayClient.CreateInvoice(invoice); + return created.Url; + }); + + private static (string? Name, string? Email, string POSData) GetSubscriberInformation( + ISubscriber subscriber) => subscriber switch + { + User user => (user.Email, user.Email, $"userId:{user.Id},accountCredit:1"), + Organization organization => (organization.Name, organization.BillingEmail, + $"organizationId:{organization.Id},accountCredit:1"), + Provider provider => (provider.Name, provider.BillingEmail, $"providerId:{provider.Id},accountCredit:1"), + _ => throw new ArgumentOutOfRangeException(nameof(subscriber)) + }; +} diff --git a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs new file mode 100644 index 0000000000..adc534bd7d --- /dev/null +++ b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs @@ -0,0 +1,129 @@ +#nullable enable +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Entities; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using Stripe; + +namespace Bit.Core.Billing.Payment.Commands; + +public interface IUpdateBillingAddressCommand +{ + Task> Run( + ISubscriber subscriber, + BillingAddress billingAddress); +} + +public class UpdateBillingAddressCommand( + ILogger logger, + IStripeAdapter stripeAdapter) : BillingCommand(logger), IUpdateBillingAddressCommand +{ + public Task> Run( + ISubscriber subscriber, + BillingAddress billingAddress) => HandleAsync(() => subscriber.GetProductUsageType() switch + { + ProductUsageType.Personal => UpdatePersonalBillingAddressAsync(subscriber, billingAddress), + ProductUsageType.Business => UpdateBusinessBillingAddressAsync(subscriber, billingAddress) + }); + + private async Task> UpdatePersonalBillingAddressAsync( + ISubscriber subscriber, + BillingAddress billingAddress) + { + var customer = + await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + new CustomerUpdateOptions + { + Address = new AddressOptions + { + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode, + Line1 = billingAddress.Line1, + Line2 = billingAddress.Line2, + City = billingAddress.City, + State = billingAddress.State + }, + Expand = ["subscriptions"] + }); + + await EnableAutomaticTaxAsync(subscriber, customer); + + return BillingAddress.From(customer.Address); + } + + private async Task> UpdateBusinessBillingAddressAsync( + ISubscriber subscriber, + BillingAddress billingAddress) + { + var customer = + await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + new CustomerUpdateOptions + { + Address = new AddressOptions + { + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode, + Line1 = billingAddress.Line1, + Line2 = billingAddress.Line2, + City = billingAddress.City, + State = billingAddress.State + }, + Expand = ["subscriptions", "tax_ids"], + TaxExempt = billingAddress.Country != "US" + ? StripeConstants.TaxExempt.Reverse + : StripeConstants.TaxExempt.None + }); + + await EnableAutomaticTaxAsync(subscriber, customer); + + var deleteExistingTaxIds = customer.TaxIds?.Any() ?? false + ? customer.TaxIds.Select(taxId => stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id)).ToList() + : []; + + if (billingAddress.TaxId == null) + { + await Task.WhenAll(deleteExistingTaxIds); + return BillingAddress.From(customer.Address); + } + + var updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id, + new TaxIdCreateOptions { Type = billingAddress.TaxId.Code, Value = billingAddress.TaxId.Value }); + + if (billingAddress.TaxId.Code == StripeConstants.TaxIdType.SpanishNIF) + { + updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id, + new TaxIdCreateOptions + { + Type = StripeConstants.TaxIdType.EUVAT, + Value = $"ES{billingAddress.TaxId.Value}" + }); + } + + await Task.WhenAll(deleteExistingTaxIds); + + return BillingAddress.From(customer.Address, updatedTaxId); + } + + private async Task EnableAutomaticTaxAsync( + ISubscriber subscriber, + Customer customer) + { + if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + var subscription = customer.Subscriptions.FirstOrDefault(subscription => + subscription.Id == subscriber.GatewaySubscriptionId); + + if (subscription is { AutomaticTax.Enabled: false }) + { + await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + } + } +} diff --git a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs new file mode 100644 index 0000000000..cda685d520 --- /dev/null +++ b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs @@ -0,0 +1,205 @@ +#nullable enable +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Braintree; +using Microsoft.Extensions.Logging; +using Stripe; +using Customer = Stripe.Customer; + +namespace Bit.Core.Billing.Payment.Commands; + +public interface IUpdatePaymentMethodCommand +{ + Task> Run( + ISubscriber subscriber, + TokenizedPaymentMethod paymentMethod, + BillingAddress? billingAddress); +} + +public class UpdatePaymentMethodCommand( + IBraintreeGateway braintreeGateway, + IGlobalSettings globalSettings, + ILogger logger, + ISetupIntentCache setupIntentCache, + IStripeAdapter stripeAdapter, + ISubscriberService subscriberService) : BillingCommand(logger), IUpdatePaymentMethodCommand +{ + private readonly ILogger _logger = logger; + private static readonly Conflict _conflict = new("We had a problem updating your payment method. Please contact support for assistance."); + + public Task> Run( + ISubscriber subscriber, + TokenizedPaymentMethod paymentMethod, + BillingAddress? billingAddress) => HandleAsync(async () => + { + var customer = await subscriberService.GetCustomer(subscriber); + + var result = paymentMethod.Type switch + { + TokenizablePaymentMethodType.BankAccount => await AddBankAccountAsync(subscriber, customer, paymentMethod.Token), + TokenizablePaymentMethodType.Card => await AddCardAsync(customer, paymentMethod.Token), + TokenizablePaymentMethodType.PayPal => await AddPayPalAsync(subscriber, customer, paymentMethod.Token), + _ => new BadRequest($"Payment method type '{paymentMethod.Type}' is not supported.") + }; + + if (billingAddress != null && customer.Address is not { Country: not null, PostalCode: not null }) + { + await stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions + { + Address = new AddressOptions + { + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode + } + }); + } + + return result; + }); + + private async Task> AddBankAccountAsync( + ISubscriber subscriber, + Customer customer, + string token) + { + var setupIntents = await stripeAdapter.SetupIntentList(new SetupIntentListOptions + { + Expand = ["data.payment_method"], + PaymentMethod = token + }); + + switch (setupIntents.Count) + { + case 0: + _logger.LogError("{Command}: Could not find setup intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id); + return _conflict; + case > 1: + _logger.LogError("{Command}: Found more than one set up intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id); + return _conflict; + } + + var setupIntent = setupIntents.First(); + + await setupIntentCache.Set(subscriber.Id, setupIntent.Id); + + await UnlinkBraintreeCustomerAsync(customer); + + return MaskedPaymentMethod.From(setupIntent); + } + + private async Task> AddCardAsync( + Customer customer, + string token) + { + var paymentMethod = await stripeAdapter.PaymentMethodAttachAsync(token, new PaymentMethodAttachOptions { Customer = customer.Id }); + + await stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions + { + InvoiceSettings = new CustomerInvoiceSettingsOptions { DefaultPaymentMethod = token } + }); + + await UnlinkBraintreeCustomerAsync(customer); + + return MaskedPaymentMethod.From(paymentMethod.Card); + } + + private async Task> AddPayPalAsync( + ISubscriber subscriber, + Customer customer, + string token) + { + Braintree.Customer braintreeCustomer; + + if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId)) + { + braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); + + await ReplaceBraintreePaymentMethodAsync(braintreeCustomer, token); + } + else + { + braintreeCustomer = await CreateBraintreeCustomerAsync(subscriber, token); + + var metadata = new Dictionary + { + [StripeConstants.MetadataKeys.BraintreeCustomerId] = braintreeCustomer.Id + }; + + await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); + } + + var payPalAccount = braintreeCustomer.DefaultPaymentMethod as PayPalAccount; + + return MaskedPaymentMethod.From(payPalAccount!); + } + + private async Task CreateBraintreeCustomerAsync( + ISubscriber subscriber, + string token) + { + var braintreeCustomerId = + subscriber.BraintreeCustomerIdPrefix() + + subscriber.Id.ToString("N").ToLower() + + CoreHelpers.RandomString(3, upper: false, numeric: false); + + var result = await braintreeGateway.Customer.CreateAsync(new CustomerRequest + { + Id = braintreeCustomerId, + CustomFields = new Dictionary + { + [subscriber.BraintreeIdField()] = subscriber.Id.ToString(), + [subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion + }, + Email = subscriber.BillingEmailAddress(), + PaymentMethodNonce = token + }); + + return result.Target; + } + + private async Task ReplaceBraintreePaymentMethodAsync( + Braintree.Customer customer, + string token) + { + var existing = customer.DefaultPaymentMethod; + + var result = await braintreeGateway.PaymentMethod.CreateAsync(new PaymentMethodRequest + { + CustomerId = customer.Id, + PaymentMethodNonce = token + }); + + await braintreeGateway.Customer.UpdateAsync( + customer.Id, + new CustomerRequest { DefaultPaymentMethodToken = result.Target.Token }); + + if (existing != null) + { + await braintreeGateway.PaymentMethod.DeleteAsync(existing.Token); + } + } + + private async Task UnlinkBraintreeCustomerAsync( + Customer customer) + { + if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId)) + { + var metadata = new Dictionary + { + [StripeConstants.MetadataKeys.RetiredBraintreeCustomerId] = braintreeCustomerId, + [StripeConstants.MetadataKeys.BraintreeCustomerId] = string.Empty + }; + + await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); + } + } +} diff --git a/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs b/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs new file mode 100644 index 0000000000..1e9492b876 --- /dev/null +++ b/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs @@ -0,0 +1,63 @@ +#nullable enable +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Entities; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using Stripe; + +namespace Bit.Core.Billing.Payment.Commands; + +public interface IVerifyBankAccountCommand +{ + Task> Run( + ISubscriber subscriber, + string descriptorCode); +} + +public class VerifyBankAccountCommand( + ILogger logger, + ISetupIntentCache setupIntentCache, + IStripeAdapter stripeAdapter) : BillingCommand(logger), IVerifyBankAccountCommand +{ + private readonly ILogger _logger = logger; + + private static readonly Conflict _conflict = + new("We had a problem verifying your bank account. Please contact support for assistance."); + + public Task> Run( + ISubscriber subscriber, + string descriptorCode) => HandleAsync(async () => + { + var setupIntentId = await setupIntentCache.Get(subscriber.Id); + + if (string.IsNullOrEmpty(setupIntentId)) + { + _logger.LogError( + "{Command}: Could not find setup intent to verify subscriber's ({SubscriberID}) bank account", + CommandName, subscriber.Id); + return _conflict; + } + + await stripeAdapter.SetupIntentVerifyMicroDeposit(setupIntentId, + new SetupIntentVerifyMicrodepositsOptions { DescriptorCode = descriptorCode }); + + var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, + new SetupIntentGetOptions { Expand = ["payment_method"] }); + + var paymentMethod = await stripeAdapter.PaymentMethodAttachAsync(setupIntent.PaymentMethodId, + new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); + + await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + new CustomerUpdateOptions + { + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + DefaultPaymentMethod = setupIntent.PaymentMethodId + } + }); + + return MaskedPaymentMethod.From(paymentMethod.UsBankAccount); + }); +} diff --git a/src/Core/Billing/Payment/Models/BillingAddress.cs b/src/Core/Billing/Payment/Models/BillingAddress.cs new file mode 100644 index 0000000000..5c2c43231c --- /dev/null +++ b/src/Core/Billing/Payment/Models/BillingAddress.cs @@ -0,0 +1,30 @@ +#nullable enable +using Stripe; + +namespace Bit.Core.Billing.Payment.Models; + +public record TaxID(string Code, string Value); + +public record BillingAddress +{ + public required string Country { get; set; } + public required string PostalCode { get; set; } + public string? Line1 { get; set; } + public string? Line2 { get; set; } + public string? City { get; set; } + public string? State { get; set; } + public TaxID? TaxId { get; set; } + + public static BillingAddress From(Address address) => new() + { + Country = address.Country, + PostalCode = address.PostalCode, + Line1 = address.Line1, + Line2 = address.Line2, + City = address.City, + State = address.State + }; + + public static BillingAddress From(Address address, TaxId? taxId) => + From(address) with { TaxId = taxId != null ? new TaxID(taxId.Type, taxId.Value) : null }; +} diff --git a/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs b/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs new file mode 100644 index 0000000000..c98fddc785 --- /dev/null +++ b/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs @@ -0,0 +1,120 @@ +#nullable enable +using System.Text.Json; +using System.Text.Json.Serialization; +using Bit.Core.Billing.Pricing.JSON; +using Braintree; +using OneOf; +using Stripe; + +namespace Bit.Core.Billing.Payment.Models; + +public record MaskedBankAccount +{ + public required string BankName { get; init; } + public required string Last4 { get; init; } + public required bool Verified { get; init; } + public string Type => "bankAccount"; +} + +public record MaskedCard +{ + public required string Brand { get; init; } + public required string Last4 { get; init; } + public required string Expiration { get; init; } + public string Type => "card"; +} + +public record MaskedPayPalAccount +{ + public required string Email { get; init; } + public string Type => "payPal"; +} + +[JsonConverter(typeof(MaskedPaymentMethodJsonConverter))] +public class MaskedPaymentMethod(OneOf input) + : OneOfBase(input) +{ + public static implicit operator MaskedPaymentMethod(MaskedBankAccount bankAccount) => new(bankAccount); + public static implicit operator MaskedPaymentMethod(MaskedCard card) => new(card); + public static implicit operator MaskedPaymentMethod(MaskedPayPalAccount payPal) => new(payPal); + + public static MaskedPaymentMethod From(BankAccount bankAccount) => new MaskedBankAccount + { + BankName = bankAccount.BankName, + Last4 = bankAccount.Last4, + Verified = bankAccount.Status == "verified" + }; + + public static MaskedPaymentMethod From(Card card) => new MaskedCard + { + Brand = card.Brand.ToLower(), + Last4 = card.Last4, + Expiration = $"{card.ExpMonth:00}/{card.ExpYear}" + }; + + public static MaskedPaymentMethod From(PaymentMethodCard card) => new MaskedCard + { + Brand = card.Brand.ToLower(), + Last4 = card.Last4, + Expiration = $"{card.ExpMonth:00}/{card.ExpYear}" + }; + + public static MaskedPaymentMethod From(SetupIntent setupIntent) => new MaskedBankAccount + { + BankName = setupIntent.PaymentMethod.UsBankAccount.BankName, + Last4 = setupIntent.PaymentMethod.UsBankAccount.Last4, + Verified = false + }; + + public static MaskedPaymentMethod From(SourceCard sourceCard) => new MaskedCard + { + Brand = sourceCard.Brand.ToLower(), + Last4 = sourceCard.Last4, + Expiration = $"{sourceCard.ExpMonth:00}/{sourceCard.ExpYear}" + }; + + public static MaskedPaymentMethod From(PaymentMethodUsBankAccount bankAccount) => new MaskedBankAccount + { + BankName = bankAccount.BankName, + Last4 = bankAccount.Last4, + Verified = true + }; + + public static MaskedPaymentMethod From(PayPalAccount payPalAccount) => new MaskedPayPalAccount { Email = payPalAccount.Email }; +} + +public class MaskedPaymentMethodJsonConverter : TypeReadingJsonConverter +{ + protected override string TypePropertyName => nameof(MaskedBankAccount.Type).ToLower(); + + public override MaskedPaymentMethod? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var type = ReadType(reader); + + return type switch + { + "bankAccount" => JsonSerializer.Deserialize(ref reader, options) switch + { + null => null, + var bankAccount => new MaskedPaymentMethod(bankAccount) + }, + "card" => JsonSerializer.Deserialize(ref reader, options) switch + { + null => null, + var card => new MaskedPaymentMethod(card) + }, + "payPal" => JsonSerializer.Deserialize(ref reader, options) switch + { + null => null, + var payPal => new MaskedPaymentMethod(payPal) + }, + _ => Skip(ref reader) + }; + } + + public override void Write(Utf8JsonWriter writer, MaskedPaymentMethod value, JsonSerializerOptions options) + => value.Switch( + bankAccount => JsonSerializer.Serialize(writer, bankAccount, options), + card => JsonSerializer.Serialize(writer, card, options), + payPal => JsonSerializer.Serialize(writer, payPal, options)); +} diff --git a/src/Core/Billing/Payment/Models/ProductUsageType.cs b/src/Core/Billing/Payment/Models/ProductUsageType.cs new file mode 100644 index 0000000000..2ecd1233c6 --- /dev/null +++ b/src/Core/Billing/Payment/Models/ProductUsageType.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Billing.Payment.Models; + +public enum ProductUsageType +{ + Personal, + Business +} diff --git a/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs b/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs new file mode 100644 index 0000000000..d27a924360 --- /dev/null +++ b/src/Core/Billing/Payment/Models/TokenizablePaymentMethodType.cs @@ -0,0 +1,8 @@ +namespace Bit.Core.Billing.Payment.Models; + +public enum TokenizablePaymentMethodType +{ + BankAccount, + Card, + PayPal +} diff --git a/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs b/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs new file mode 100644 index 0000000000..edbf1bb121 --- /dev/null +++ b/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs @@ -0,0 +1,8 @@ +#nullable enable +namespace Bit.Core.Billing.Payment.Models; + +public record TokenizedPaymentMethod +{ + public required TokenizablePaymentMethodType Type { get; set; } + public required string Token { get; set; } +} diff --git a/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs b/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs new file mode 100644 index 0000000000..84d4d4f377 --- /dev/null +++ b/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs @@ -0,0 +1,41 @@ +#nullable enable +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Stripe; + +namespace Bit.Core.Billing.Payment.Queries; + +public interface IGetBillingAddressQuery +{ + Task Run(ISubscriber subscriber); +} + +public class GetBillingAddressQuery( + ISubscriberService subscriberService) : IGetBillingAddressQuery +{ + public async Task Run(ISubscriber subscriber) + { + var productUsageType = subscriber.GetProductUsageType(); + + var options = productUsageType switch + { + ProductUsageType.Business => new CustomerGetOptions { Expand = ["tax_ids"] }, + _ => new CustomerGetOptions() + }; + + var customer = await subscriberService.GetCustomer(subscriber, options); + + if (customer is not { Address: { Country: not null, PostalCode: not null } }) + { + return null; + } + + var taxId = productUsageType == ProductUsageType.Business ? customer.TaxIds?.FirstOrDefault() : null; + + return taxId != null + ? BillingAddress.From(customer.Address, taxId) + : BillingAddress.From(customer.Address); + } +} diff --git a/src/Core/Billing/Payment/Queries/GetCreditQuery.cs b/src/Core/Billing/Payment/Queries/GetCreditQuery.cs new file mode 100644 index 0000000000..79c9a13aba --- /dev/null +++ b/src/Core/Billing/Payment/Queries/GetCreditQuery.cs @@ -0,0 +1,26 @@ +#nullable enable +using Bit.Core.Billing.Services; +using Bit.Core.Entities; + +namespace Bit.Core.Billing.Payment.Queries; + +public interface IGetCreditQuery +{ + Task Run(ISubscriber subscriber); +} + +public class GetCreditQuery( + ISubscriberService subscriberService) : IGetCreditQuery +{ + public async Task Run(ISubscriber subscriber) + { + var customer = await subscriberService.GetCustomer(subscriber); + + if (customer == null) + { + return null; + } + + return Convert.ToDecimal(customer.Balance) * -1 / 100; + } +} diff --git a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs new file mode 100644 index 0000000000..eb42a8c78a --- /dev/null +++ b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs @@ -0,0 +1,96 @@ +#nullable enable +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Services; +using Braintree; +using Microsoft.Extensions.Logging; +using Stripe; + +namespace Bit.Core.Billing.Payment.Queries; + +public interface IGetPaymentMethodQuery +{ + Task Run(ISubscriber subscriber); +} + +public class GetPaymentMethodQuery( + IBraintreeGateway braintreeGateway, + ILogger logger, + ISetupIntentCache setupIntentCache, + IStripeAdapter stripeAdapter, + ISubscriberService subscriberService) : IGetPaymentMethodQuery +{ + public async Task Run(ISubscriber subscriber) + { + var customer = await subscriberService.GetCustomer(subscriber, + new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] }); + + if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId)) + { + var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); + + if (braintreeCustomer.DefaultPaymentMethod is PayPalAccount payPalAccount) + { + return new MaskedPayPalAccount { Email = payPalAccount.Email }; + } + + logger.LogWarning("Subscriber ({SubscriberID}) has a linked Braintree customer ({BraintreeCustomerId}) with no PayPal account.", subscriber.Id, braintreeCustomerId); + + return null; + } + + var paymentMethod = customer.InvoiceSettings.DefaultPaymentMethod != null + ? customer.InvoiceSettings.DefaultPaymentMethod.Type switch + { + "card" => MaskedPaymentMethod.From(customer.InvoiceSettings.DefaultPaymentMethod.Card), + "us_bank_account" => MaskedPaymentMethod.From(customer.InvoiceSettings.DefaultPaymentMethod.UsBankAccount), + _ => null + } + : null; + + if (paymentMethod != null) + { + return paymentMethod; + } + + if (customer.DefaultSource != null) + { + paymentMethod = customer.DefaultSource switch + { + Card card => MaskedPaymentMethod.From(card), + BankAccount bankAccount => MaskedPaymentMethod.From(bankAccount), + Source { Card: not null } source => MaskedPaymentMethod.From(source.Card), + _ => null + }; + + if (paymentMethod != null) + { + return paymentMethod; + } + } + + var setupIntentId = await setupIntentCache.Get(subscriber.Id); + + if (string.IsNullOrEmpty(setupIntentId)) + { + return null; + } + + var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + { + Expand = ["payment_method"] + }); + + // ReSharper disable once ConvertIfStatementToReturnStatement + if (!setupIntent.IsUnverifiedBankAccount()) + { + return null; + } + + return MaskedPaymentMethod.From(setupIntent); + } +} diff --git a/src/Core/Billing/Payment/Registrations.cs b/src/Core/Billing/Payment/Registrations.cs new file mode 100644 index 0000000000..1cc7914f10 --- /dev/null +++ b/src/Core/Billing/Payment/Registrations.cs @@ -0,0 +1,24 @@ +using Bit.Core.Billing.Payment.Clients; +using Bit.Core.Billing.Payment.Commands; +using Bit.Core.Billing.Payment.Queries; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Core.Billing.Payment; + +public static class Registrations +{ + public static void AddPaymentOperations(this IServiceCollection services) + { + // Commands + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + + // Queries + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + } +} diff --git a/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs b/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs index ef8d33304e..05beccdb60 100644 --- a/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs +++ b/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs @@ -6,7 +6,7 @@ namespace Bit.Core.Billing.Pricing.JSON; #nullable enable -public abstract class TypeReadingJsonConverter : JsonConverter +public abstract class TypeReadingJsonConverter : JsonConverter where T : class { protected virtual string TypePropertyName => nameof(ScalableDTO.Type).ToLower(); @@ -14,7 +14,9 @@ public abstract class TypeReadingJsonConverter : JsonConverter { while (reader.Read()) { - if (reader.TokenType != JsonTokenType.PropertyName || reader.GetString()?.ToLower() != TypePropertyName) + if (reader.CurrentDepth != 1 || + reader.TokenType != JsonTokenType.PropertyName || + reader.GetString()?.ToLower() != TypePropertyName) { continue; } @@ -25,4 +27,10 @@ public abstract class TypeReadingJsonConverter : JsonConverter return null; } + + protected T? Skip(ref Utf8JsonReader reader) + { + reader.Skip(); + return null; + } } diff --git a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs index c777d0c0d1..86f233232f 100644 --- a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs +++ b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs @@ -1,8 +1,8 @@ #nullable enable +using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Tax.Services; using Bit.Core.Services; @@ -20,111 +20,95 @@ public class PreviewTaxAmountCommand( ILogger logger, IPricingClient pricingClient, IStripeAdapter stripeAdapter, - ITaxService taxService) : IPreviewTaxAmountCommand + ITaxService taxService) : BillingCommand(logger), IPreviewTaxAmountCommand { - public async Task> Run(OrganizationTrialParameters parameters) - { - var (planType, productType, taxInformation) = parameters; - - var plan = await pricingClient.GetPlanOrThrow(planType); - - var options = new InvoiceCreatePreviewOptions + public Task> Run(OrganizationTrialParameters parameters) + => HandleAsync(async () => { - Currency = "usd", - CustomerDetails = new InvoiceCustomerDetailsOptions + var (planType, productType, taxInformation) = parameters; + + var plan = await pricingClient.GetPlanOrThrow(planType); + + var options = new InvoiceCreatePreviewOptions { - Address = new AddressOptions + Currency = "usd", + CustomerDetails = new InvoiceCustomerDetailsOptions { - Country = taxInformation.Country, - PostalCode = taxInformation.PostalCode - } - }, - SubscriptionDetails = new InvoiceSubscriptionDetailsOptions - { - Items = [ - new InvoiceSubscriptionDetailsItemOptions + Address = new AddressOptions { - Price = plan.HasNonSeatBasedPasswordManagerPlan() ? plan.PasswordManager.StripePlanId : plan.PasswordManager.StripeSeatPlanId, - Quantity = 1 + Country = taxInformation.Country, + PostalCode = taxInformation.PostalCode } - ] - } - }; - - if (productType == ProductType.SecretsManager) - { - options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions - { - Price = plan.SecretsManager.StripeSeatPlanId, - Quantity = 1 - }); - - options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone; - } - - if (!string.IsNullOrEmpty(taxInformation.TaxId)) - { - var taxIdType = taxService.GetStripeTaxCode( - taxInformation.Country, - taxInformation.TaxId); - - if (string.IsNullOrEmpty(taxIdType)) - { - return BadRequest.UnknownTaxIdType; - } - - options.CustomerDetails.TaxIds = [ - new InvoiceCustomerDetailsTaxIdOptions + }, + SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { - Type = taxIdType, - Value = taxInformation.TaxId + Items = + [ + new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.HasNonSeatBasedPasswordManagerPlan() + ? plan.PasswordManager.StripePlanId + : plan.PasswordManager.StripeSeatPlanId, + Quantity = 1 + } + ] } - ]; - - if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) - { - options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions - { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{parameters.TaxInformation.TaxId}" - }); - } - } - - if (planType.GetProductTier() == ProductTierType.Families) - { - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; - } - else - { - options.AutomaticTax = new InvoiceAutomaticTaxOptions - { - Enabled = options.CustomerDetails.Address.Country == "US" || - options.CustomerDetails.TaxIds is [_, ..] }; - } - try - { + if (productType == ProductType.SecretsManager) + { + options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.SecretsManager.StripeSeatPlanId, + Quantity = 1 + }); + + options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone; + } + + if (!string.IsNullOrEmpty(taxInformation.TaxId)) + { + var taxIdType = taxService.GetStripeTaxCode( + taxInformation.Country, + taxInformation.TaxId); + + if (string.IsNullOrEmpty(taxIdType)) + { + return new BadRequest( + "We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support for assistance."); + } + + options.CustomerDetails.TaxIds = + [ + new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = taxInformation.TaxId } + ]; + + if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) + { + options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions + { + Type = StripeConstants.TaxIdType.EUVAT, + Value = $"ES{parameters.TaxInformation.TaxId}" + }); + } + } + + if (planType.GetProductTier() == ProductTierType.Families) + { + options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; + } + else + { + options.AutomaticTax = new InvoiceAutomaticTaxOptions + { + Enabled = options.CustomerDetails.Address.Country == "US" || + options.CustomerDetails.TaxIds is [_, ..] + }; + } + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); return Convert.ToDecimal(invoice.Tax) / 100; - } - catch (StripeException stripeException) when (stripeException.StripeError.Code == - StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) - { - return BadRequest.TaxLocationInvalid; - } - catch (StripeException stripeException) when (stripeException.StripeError.Code == - StripeConstants.ErrorCodes.TaxIdInvalid) - { - return BadRequest.TaxIdNumberInvalid; - } - catch (StripeException stripeException) - { - logger.LogError(stripeException, "Stripe responded with an error during {Operation}. Code: {Code}", nameof(PreviewTaxAmountCommand), stripeException.StripeError.Code); - return new Unhandled(); - } - } + }); } #region Command Parameters diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 2a3b619de6..5c0bd3216e 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -155,6 +155,7 @@ public static class FeatureFlagKeys public const string PM20322_AllowTrialLength0 = "pm-20322-allow-trial-length-0"; public const string PM21092_SetNonUSBusinessUseToReverseCharge = "pm-21092-set-non-us-business-use-to-reverse-charge"; public const string PM21383_GetProviderPriceFromStripe = "pm-21383-get-provider-price-from-stripe"; + public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout"; /* Data Insights and Reporting Team */ public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application"; diff --git a/test/Api.Test/Billing/Attributes/InjectOrganizationAttributeTests.cs b/test/Api.Test/Billing/Attributes/InjectOrganizationAttributeTests.cs new file mode 100644 index 0000000000..252c457924 --- /dev/null +++ b/test/Api.Test/Billing/Attributes/InjectOrganizationAttributeTests.cs @@ -0,0 +1,132 @@ +using Bit.Api.Billing.Attributes; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Models.Api; +using Bit.Core.Repositories; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Abstractions; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Mvc.ModelBinding; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Billing.Attributes; + +public class InjectOrganizationAttributeTests +{ + private readonly IOrganizationRepository _organizationRepository; + private readonly ActionExecutionDelegate _next; + private readonly ActionExecutingContext _context; + private readonly Organization _organization; + private readonly Guid _organizationId; + + public InjectOrganizationAttributeTests() + { + _organizationRepository = Substitute.For(); + _organizationId = Guid.NewGuid(); + _organization = new Organization { Id = _organizationId }; + + var httpContext = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddScoped(_ => _organizationRepository); + httpContext.RequestServices = services.BuildServiceProvider(); + + var routeData = new RouteData { Values = { ["organizationId"] = _organizationId.ToString() } }; + + var actionContext = new ActionContext( + httpContext, + routeData, + new ActionDescriptor(), + new ModelStateDictionary() + ); + + _next = () => Task.FromResult(new ActionExecutedContext( + actionContext, + new List(), + new object())); + + _context = new ActionExecutingContext( + actionContext, + new List(), + new Dictionary(), + new object()); + } + + [Fact] + public async Task OnActionExecutionAsync_WithExistingOrganization_InjectsOrganization() + { + var attribute = new InjectOrganizationAttribute(); + _organizationRepository.GetByIdAsync(_organizationId) + .Returns(_organization); + + var parameter = new ParameterDescriptor + { + Name = "organization", + ParameterType = typeof(Organization) + }; + _context.ActionDescriptor.Parameters = [parameter]; + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.Equal(_organization, _context.ActionArguments["organization"]); + } + + [Fact] + public async Task OnActionExecutionAsync_WithNonExistentOrganization_ReturnsNotFound() + { + var attribute = new InjectOrganizationAttribute(); + _organizationRepository.GetByIdAsync(_organizationId) + .Returns((Organization)null); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.IsType(_context.Result); + var result = (NotFoundObjectResult)_context.Result; + Assert.IsType(result.Value); + Assert.Equal("Organization not found.", ((ErrorResponseModel)result.Value).Message); + } + + [Fact] + public async Task OnActionExecutionAsync_WithInvalidOrganizationId_ReturnsBadRequest() + { + var attribute = new InjectOrganizationAttribute(); + _context.RouteData.Values["organizationId"] = "not-a-guid"; + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.IsType(_context.Result); + var result = (BadRequestObjectResult)_context.Result; + Assert.IsType(result.Value); + Assert.Equal("Route parameter 'organizationId' is missing or invalid.", ((ErrorResponseModel)result.Value).Message); + } + + [Fact] + public async Task OnActionExecutionAsync_WithMissingOrganizationId_ReturnsBadRequest() + { + var attribute = new InjectOrganizationAttribute(); + _context.RouteData.Values.Clear(); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.IsType(_context.Result); + var result = (BadRequestObjectResult)_context.Result; + Assert.IsType(result.Value); + Assert.Equal("Route parameter 'organizationId' is missing or invalid.", ((ErrorResponseModel)result.Value).Message); + } + + [Fact] + public async Task OnActionExecutionAsync_WithoutOrganizationParameter_ContinuesExecution() + { + var attribute = new InjectOrganizationAttribute(); + _organizationRepository.GetByIdAsync(_organizationId) + .Returns(_organization); + + _context.ActionDescriptor.Parameters = Array.Empty(); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.Empty(_context.ActionArguments); + } +} diff --git a/test/Api.Test/Billing/Attributes/InjectProviderAttributeTests.cs b/test/Api.Test/Billing/Attributes/InjectProviderAttributeTests.cs new file mode 100644 index 0000000000..0a3e19f8b1 --- /dev/null +++ b/test/Api.Test/Billing/Attributes/InjectProviderAttributeTests.cs @@ -0,0 +1,190 @@ +using Bit.Api.Billing.Attributes; +using Bit.Api.Models.Public.Response; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Context; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Abstractions; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Mvc.ModelBinding; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Billing.Attributes; + +public class InjectProviderAttributeTests +{ + private readonly IProviderRepository _providerRepository; + private readonly ICurrentContext _currentContext; + private readonly ActionExecutionDelegate _next; + private readonly ActionExecutingContext _context; + private readonly Provider _provider; + private readonly Guid _providerId; + + public InjectProviderAttributeTests() + { + _providerRepository = Substitute.For(); + _currentContext = Substitute.For(); + _providerId = Guid.NewGuid(); + _provider = new Provider { Id = _providerId }; + + var httpContext = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddScoped(_ => _providerRepository); + services.AddScoped(_ => _currentContext); + httpContext.RequestServices = services.BuildServiceProvider(); + + var routeData = new RouteData { Values = { ["providerId"] = _providerId.ToString() } }; + + var actionContext = new ActionContext( + httpContext, + routeData, + new ActionDescriptor(), + new ModelStateDictionary() + ); + + _next = () => Task.FromResult(new ActionExecutedContext( + actionContext, + new List(), + new object())); + + _context = new ActionExecutingContext( + actionContext, + new List(), + new Dictionary(), + new object()); + } + + [Fact] + public async Task OnActionExecutionAsync_WithExistingProvider_InjectsProvider() + { + var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin); + _providerRepository.GetByIdAsync(_providerId).Returns(_provider); + _currentContext.ProviderProviderAdmin(_providerId).Returns(true); + + var parameter = new ParameterDescriptor + { + Name = "provider", + ParameterType = typeof(Provider) + }; + _context.ActionDescriptor.Parameters = [parameter]; + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.Equal(_provider, _context.ActionArguments["provider"]); + } + + [Fact] + public async Task OnActionExecutionAsync_WithNonExistentProvider_ReturnsNotFound() + { + var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin); + _providerRepository.GetByIdAsync(_providerId).Returns((Provider)null); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.IsType(_context.Result); + var result = (NotFoundObjectResult)_context.Result; + Assert.IsType(result.Value); + Assert.Equal("Provider not found.", ((ErrorResponseModel)result.Value).Message); + } + + [Fact] + public async Task OnActionExecutionAsync_WithInvalidProviderId_ReturnsBadRequest() + { + var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin); + _context.RouteData.Values["providerId"] = "not-a-guid"; + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.IsType(_context.Result); + var result = (BadRequestObjectResult)_context.Result; + Assert.IsType(result.Value); + Assert.Equal("Route parameter 'providerId' is missing or invalid.", ((ErrorResponseModel)result.Value).Message); + } + + [Fact] + public async Task OnActionExecutionAsync_WithMissingProviderId_ReturnsBadRequest() + { + var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin); + _context.RouteData.Values.Clear(); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.IsType(_context.Result); + var result = (BadRequestObjectResult)_context.Result; + Assert.IsType(result.Value); + Assert.Equal("Route parameter 'providerId' is missing or invalid.", ((ErrorResponseModel)result.Value).Message); + } + + [Fact] + public async Task OnActionExecutionAsync_WithoutProviderParameter_ContinuesExecution() + { + var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin); + _providerRepository.GetByIdAsync(_providerId).Returns(_provider); + _currentContext.ProviderProviderAdmin(_providerId).Returns(true); + + _context.ActionDescriptor.Parameters = Array.Empty(); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.Empty(_context.ActionArguments); + } + + [Fact] + public async Task OnActionExecutionAsync_UnauthorizedProviderAdmin_ReturnsUnauthorized() + { + var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin); + _providerRepository.GetByIdAsync(_providerId).Returns(_provider); + _currentContext.ProviderProviderAdmin(_providerId).Returns(false); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.IsType(_context.Result); + var result = (UnauthorizedObjectResult)_context.Result; + Assert.IsType(result.Value); + Assert.Equal("Unauthorized.", ((ErrorResponseModel)result.Value).Message); + } + + [Fact] + public async Task OnActionExecutionAsync_UnauthorizedServiceUser_ReturnsUnauthorized() + { + var attribute = new InjectProviderAttribute(ProviderUserType.ServiceUser); + _providerRepository.GetByIdAsync(_providerId).Returns(_provider); + _currentContext.ProviderUser(_providerId).Returns(false); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.IsType(_context.Result); + var result = (UnauthorizedObjectResult)_context.Result; + Assert.IsType(result.Value); + Assert.Equal("Unauthorized.", ((ErrorResponseModel)result.Value).Message); + } + + [Fact] + public async Task OnActionExecutionAsync_AuthorizedProviderAdmin_Succeeds() + { + var attribute = new InjectProviderAttribute(ProviderUserType.ProviderAdmin); + _providerRepository.GetByIdAsync(_providerId).Returns(_provider); + _currentContext.ProviderProviderAdmin(_providerId).Returns(true); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.Null(_context.Result); + } + + [Fact] + public async Task OnActionExecutionAsync_AuthorizedServiceUser_Succeeds() + { + var attribute = new InjectProviderAttribute(ProviderUserType.ServiceUser); + _providerRepository.GetByIdAsync(_providerId).Returns(_provider); + _currentContext.ProviderUser(_providerId).Returns(true); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.Null(_context.Result); + } +} diff --git a/test/Api.Test/Billing/Attributes/InjectUserAttributesTests.cs b/test/Api.Test/Billing/Attributes/InjectUserAttributesTests.cs new file mode 100644 index 0000000000..5c26cca64a --- /dev/null +++ b/test/Api.Test/Billing/Attributes/InjectUserAttributesTests.cs @@ -0,0 +1,129 @@ +using System.Security.Claims; +using Bit.Api.Billing.Attributes; +using Bit.Core.Entities; +using Bit.Core.Models.Api; +using Bit.Core.Services; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Abstractions; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Mvc.ModelBinding; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Billing.Attributes; + +public class InjectUserAttributesTests +{ + private readonly IUserService _userService; + private readonly ActionExecutionDelegate _next; + private readonly ActionExecutingContext _context; + private readonly User _user; + + public InjectUserAttributesTests() + { + _userService = Substitute.For(); + _user = new User { Id = Guid.NewGuid() }; + + var httpContext = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddScoped(_ => _userService); + httpContext.RequestServices = services.BuildServiceProvider(); + + var actionContext = new ActionContext( + httpContext, + new RouteData(), + new ActionDescriptor(), + new ModelStateDictionary() + ); + + _next = () => Task.FromResult(new ActionExecutedContext( + actionContext, + new List(), + new object())); + + _context = new ActionExecutingContext( + actionContext, + new List(), + new Dictionary(), + new object()); + } + + [Fact] + public async Task OnActionExecutionAsync_WithAuthorizedUser_InjectsUser() + { + var attribute = new InjectUserAttribute(); + _userService.GetUserByPrincipalAsync(Arg.Any()) + .Returns(_user); + + var parameter = new ParameterDescriptor + { + Name = "user", + ParameterType = typeof(User) + }; + _context.ActionDescriptor.Parameters = [parameter]; + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.Equal(_user, _context.ActionArguments["user"]); + } + + [Fact] + public async Task OnActionExecutionAsync_WithUnauthorizedUser_ReturnsUnauthorized() + { + var attribute = new InjectUserAttribute(); + _userService.GetUserByPrincipalAsync(Arg.Any()) + .Returns((User)null); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.IsType(_context.Result); + var result = (UnauthorizedObjectResult)_context.Result; + Assert.IsType(result.Value); + Assert.Equal("Unauthorized.", ((ErrorResponseModel)result.Value).Message); + } + + [Fact] + public async Task OnActionExecutionAsync_WithoutUserParameter_ContinuesExecution() + { + var attribute = new InjectUserAttribute(); + _userService.GetUserByPrincipalAsync(Arg.Any()) + .Returns(_user); + + _context.ActionDescriptor.Parameters = Array.Empty(); + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.Empty(_context.ActionArguments); + } + + [Fact] + public async Task OnActionExecutionAsync_WithMultipleParameters_InjectsUserCorrectly() + { + var attribute = new InjectUserAttribute(); + _userService.GetUserByPrincipalAsync(Arg.Any()) + .Returns(_user); + + var parameters = new[] + { + new ParameterDescriptor + { + Name = "otherParam", + ParameterType = typeof(string) + }, + new ParameterDescriptor + { + Name = "user", + ParameterType = typeof(User) + } + }; + _context.ActionDescriptor.Parameters = parameters; + + await attribute.OnActionExecutionAsync(_context, _next); + + Assert.Single(_context.ActionArguments); + Assert.Equal(_user, _context.ActionArguments["user"]); + } +} diff --git a/test/Core.Test/Billing/Extensions/StripeExtensions.cs b/test/Core.Test/Billing/Extensions/StripeExtensions.cs new file mode 100644 index 0000000000..44948bbfed --- /dev/null +++ b/test/Core.Test/Billing/Extensions/StripeExtensions.cs @@ -0,0 +1,18 @@ +using Bit.Core.Billing.Payment.Models; +using Stripe; + +namespace Bit.Core.Test.Billing.Extensions; + +public static class StripeExtensions +{ + public static bool HasExpansions(this BaseOptions options, params string[] expansions) + => expansions.All(expansion => options.Expand.Contains(expansion)); + + public static bool Matches(this AddressOptions address, BillingAddress billingAddress) => + address.Country == billingAddress.Country && + address.PostalCode == billingAddress.PostalCode && + address.Line1 == billingAddress.Line1 && + address.Line2 == billingAddress.Line2 && + address.City == billingAddress.City && + address.State == billingAddress.State; +} diff --git a/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs new file mode 100644 index 0000000000..800c3ec3ae --- /dev/null +++ b/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs @@ -0,0 +1,94 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Payment.Clients; +using Bit.Core.Billing.Payment.Commands; +using Bit.Core.Entities; +using Bit.Core.Settings; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; +using Invoice = BitPayLight.Models.Invoice.Invoice; + +namespace Bit.Core.Test.Billing.Payment.Commands; + +public class CreateBitPayInvoiceForCreditCommandTests +{ + private readonly IBitPayClient _bitPayClient = Substitute.For(); + private readonly GlobalSettings _globalSettings = new() + { + BitPay = new GlobalSettings.BitPaySettings { NotificationUrl = "https://example.com/bitpay/notification" } + }; + private const string _redirectUrl = "https://bitwarden.com/redirect"; + private readonly CreateBitPayInvoiceForCreditCommand _command; + + public CreateBitPayInvoiceForCreditCommandTests() + { + _command = new CreateBitPayInvoiceForCreditCommand( + _bitPayClient, + _globalSettings, + Substitute.For>()); + } + + [Fact] + public async Task Run_User_CreatesInvoice_ReturnsInvoiceUrl() + { + var user = new User { Id = Guid.NewGuid(), Email = "user@gmail.com" }; + + _bitPayClient.CreateInvoice(Arg.Is(options => + options.Buyer.Email == user.Email && + options.Buyer.Name == user.Email && + options.NotificationUrl == _globalSettings.BitPay.NotificationUrl && + options.PosData == $"userId:{user.Id},accountCredit:1" && + // ReSharper disable once CompareOfFloatsByEqualityOperator + options.Price == Convert.ToDouble(10M) && + options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" }); + + var result = await _command.Run(user, 10M, _redirectUrl); + + Assert.True(result.IsT0); + var invoiceUrl = result.AsT0; + Assert.Equal("https://bitpay.com/invoice/123", invoiceUrl); + } + + [Fact] + public async Task Run_Organization_CreatesInvoice_ReturnsInvoiceUrl() + { + var organization = new Organization { Id = Guid.NewGuid(), BillingEmail = "organization@example.com", Name = "Organization" }; + + _bitPayClient.CreateInvoice(Arg.Is(options => + options.Buyer.Email == organization.BillingEmail && + options.Buyer.Name == organization.Name && + options.NotificationUrl == _globalSettings.BitPay.NotificationUrl && + options.PosData == $"organizationId:{organization.Id},accountCredit:1" && + // ReSharper disable once CompareOfFloatsByEqualityOperator + options.Price == Convert.ToDouble(10M) && + options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" }); + + var result = await _command.Run(organization, 10M, _redirectUrl); + + Assert.True(result.IsT0); + var invoiceUrl = result.AsT0; + Assert.Equal("https://bitpay.com/invoice/123", invoiceUrl); + } + + [Fact] + public async Task Run_Provider_CreatesInvoice_ReturnsInvoiceUrl() + { + var provider = new Provider { Id = Guid.NewGuid(), BillingEmail = "organization@example.com", Name = "Provider" }; + + _bitPayClient.CreateInvoice(Arg.Is(options => + options.Buyer.Email == provider.BillingEmail && + options.Buyer.Name == provider.Name && + options.NotificationUrl == _globalSettings.BitPay.NotificationUrl && + options.PosData == $"providerId:{provider.Id},accountCredit:1" && + // ReSharper disable once CompareOfFloatsByEqualityOperator + options.Price == Convert.ToDouble(10M) && + options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" }); + + var result = await _command.Run(provider, 10M, _redirectUrl); + + Assert.True(result.IsT0); + var invoiceUrl = result.AsT0; + Assert.Equal("https://bitpay.com/invoice/123", invoiceUrl); + } +} diff --git a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs new file mode 100644 index 0000000000..453d0c78e9 --- /dev/null +++ b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs @@ -0,0 +1,349 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Payment.Commands; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Services; +using Bit.Core.Test.Billing.Extensions; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Payment.Commands; + +using static StripeConstants; + +public class UpdateBillingAddressCommandTests +{ + private readonly IStripeAdapter _stripeAdapter; + private readonly UpdateBillingAddressCommand _command; + + public UpdateBillingAddressCommandTests() + { + _stripeAdapter = Substitute.For(); + _command = new UpdateBillingAddressCommand( + Substitute.For>(), + _stripeAdapter); + } + + [Fact] + public async Task Run_PersonalOrganization_MakesCorrectInvocations_ReturnsBillingAddress() + { + var organization = new Organization + { + PlanType = PlanType.FamiliesAnnually, + GatewayCustomerId = "cus_123", + GatewaySubscriptionId = "sub_123" + }; + + var input = new BillingAddress + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Suite 100", + City = "New York", + State = "NY" + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Suite 100", + City = "New York", + State = "NY" + }, + Subscriptions = new StripeList + { + Data = + [ + new Subscription + { + Id = organization.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } + }; + + _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Address.Matches(input) && + options.HasExpansions("subscriptions") + )).Returns(customer); + + var result = await _command.Run(organization, input); + + Assert.True(result.IsT0); + var output = result.AsT0; + Assert.Equivalent(input, output); + + await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); + } + + [Fact] + public async Task Run_BusinessOrganization_MakesCorrectInvocations_ReturnsBillingAddress() + { + var organization = new Organization + { + PlanType = PlanType.EnterpriseAnnually, + GatewayCustomerId = "cus_123", + GatewaySubscriptionId = "sub_123" + }; + + var input = new BillingAddress + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Suite 100", + City = "New York", + State = "NY" + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Suite 100", + City = "New York", + State = "NY" + }, + Subscriptions = new StripeList + { + Data = + [ + new Subscription + { + Id = organization.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } + }; + + _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Address.Matches(input) && + options.HasExpansions("subscriptions", "tax_ids") && + options.TaxExempt == TaxExempt.None + )).Returns(customer); + + var result = await _command.Run(organization, input); + + Assert.True(result.IsT0); + var output = result.AsT0; + Assert.Equivalent(input, output); + + await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); + } + + [Fact] + public async Task Run_BusinessOrganization_RemovingTaxId_MakesCorrectInvocations_ReturnsBillingAddress() + { + var organization = new Organization + { + PlanType = PlanType.EnterpriseAnnually, + GatewayCustomerId = "cus_123", + GatewaySubscriptionId = "sub_123" + }; + + var input = new BillingAddress + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Suite 100", + City = "New York", + State = "NY" + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Suite 100", + City = "New York", + State = "NY" + }, + Id = organization.GatewayCustomerId, + Subscriptions = new StripeList + { + Data = + [ + new Subscription + { + Id = organization.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + }, + TaxIds = new StripeList + { + Data = + [ + new TaxId { Id = "tax_id_123", Type = "us_ein", Value = "123456789" } + ] + } + }; + + _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Address.Matches(input) && + options.HasExpansions("subscriptions", "tax_ids") && + options.TaxExempt == TaxExempt.None + )).Returns(customer); + + var result = await _command.Run(organization, input); + + Assert.True(result.IsT0); + var output = result.AsT0; + Assert.Equivalent(input, output); + + await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); + + await _stripeAdapter.Received(1).TaxIdDeleteAsync(customer.Id, "tax_id_123"); + } + + [Fact] + public async Task Run_NonUSBusinessOrganization_MakesCorrectInvocations_ReturnsBillingAddress() + { + var organization = new Organization + { + PlanType = PlanType.EnterpriseAnnually, + GatewayCustomerId = "cus_123", + GatewaySubscriptionId = "sub_123" + }; + + var input = new BillingAddress + { + Country = "DE", + PostalCode = "10115", + Line1 = "Friedrichstraße 123", + Line2 = "Stock 3", + City = "Berlin", + State = "Berlin" + }; + + var customer = new Customer + { + Address = new Address + { + Country = "DE", + PostalCode = "10115", + Line1 = "Friedrichstraße 123", + Line2 = "Stock 3", + City = "Berlin", + State = "Berlin" + }, + Subscriptions = new StripeList + { + Data = + [ + new Subscription + { + Id = organization.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } + }; + + _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Address.Matches(input) && + options.HasExpansions("subscriptions", "tax_ids") && + options.TaxExempt == TaxExempt.Reverse + )).Returns(customer); + + var result = await _command.Run(organization, input); + + Assert.True(result.IsT0); + var output = result.AsT0; + Assert.Equivalent(input, output); + + await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); + } + + [Fact] + public async Task Run_BusinessOrganizationWithSpanishCIF_MakesCorrectInvocations_ReturnsBillingAddress() + { + var organization = new Organization + { + PlanType = PlanType.EnterpriseAnnually, + GatewayCustomerId = "cus_123", + GatewaySubscriptionId = "sub_123" + }; + + var input = new BillingAddress + { + Country = "ES", + PostalCode = "28001", + Line1 = "Calle de Serrano 41", + Line2 = "Planta 3", + City = "Madrid", + State = "Madrid", + TaxId = new TaxID(TaxIdType.SpanishNIF, "A12345678") + }; + + var customer = new Customer + { + Address = new Address + { + Country = "ES", + PostalCode = "28001", + Line1 = "Calle de Serrano 41", + Line2 = "Planta 3", + City = "Madrid", + State = "Madrid" + }, + Id = organization.GatewayCustomerId, + Subscriptions = new StripeList + { + Data = + [ + new Subscription + { + Id = organization.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } + }; + + _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Address.Matches(input) && + options.HasExpansions("subscriptions", "tax_ids") && + options.TaxExempt == TaxExempt.Reverse + )).Returns(customer); + + _stripeAdapter + .TaxIdCreateAsync(customer.Id, + Arg.Is(options => options.Type == TaxIdType.EUVAT)) + .Returns(new TaxId { Type = TaxIdType.EUVAT, Value = "ESA12345678" }); + + var result = await _command.Run(organization, input); + + Assert.True(result.IsT0); + var output = result.AsT0; + Assert.Equivalent(input with { TaxId = new TaxID(TaxIdType.EUVAT, "ESA12345678") }, output); + + await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); + + await _stripeAdapter.Received(1).TaxIdCreateAsync(organization.GatewayCustomerId, Arg.Is( + options => options.Type == TaxIdType.SpanishNIF && + options.Value == input.TaxId.Value)); + } +} diff --git a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs new file mode 100644 index 0000000000..e7bc5c787c --- /dev/null +++ b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs @@ -0,0 +1,399 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Payment.Commands; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Services; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Test.Billing.Extensions; +using Braintree; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Stripe; +using Xunit; +using Address = Stripe.Address; +using Customer = Stripe.Customer; +using PaymentMethod = Stripe.PaymentMethod; + +namespace Bit.Core.Test.Billing.Payment.Commands; + +using static StripeConstants; + +public class UpdatePaymentMethodCommandTests +{ + private readonly IBraintreeGateway _braintreeGateway = Substitute.For(); + private readonly IGlobalSettings _globalSettings = Substitute.For(); + private readonly ISetupIntentCache _setupIntentCache = Substitute.For(); + private readonly IStripeAdapter _stripeAdapter = Substitute.For(); + private readonly ISubscriberService _subscriberService = Substitute.For(); + private readonly UpdatePaymentMethodCommand _command; + + public UpdatePaymentMethodCommandTests() + { + _command = new UpdatePaymentMethodCommand( + _braintreeGateway, + _globalSettings, + Substitute.For>(), + _setupIntentCache, + _stripeAdapter, + _subscriberService); + } + + [Fact] + public async Task Run_BankAccount_MakesCorrectInvocations_ReturnsMaskedBankAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345" + }, + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + const string token = "TOKEN"; + + var setupIntent = new SetupIntent + { + Id = "seti_123", + PaymentMethod = + new PaymentMethod + { + Type = "us_bank_account", + UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } + }, + NextAction = new SetupIntentNextAction + { + VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() + }, + Status = "requires_action" + }; + + _stripeAdapter.SetupIntentList(Arg.Is(options => + options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); + + var result = await _command.Run(organization, + new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = token }, new BillingAddress + { + Country = "US", + PostalCode = "12345" + }); + + Assert.True(result.IsT0); + var maskedPaymentMethod = result.AsT0; + Assert.True(maskedPaymentMethod.IsT0); + var maskedBankAccount = maskedPaymentMethod.AsT0; + Assert.Equal("Chase", maskedBankAccount.BankName); + Assert.Equal("9999", maskedBankAccount.Last4); + Assert.False(maskedBankAccount.Verified); + + await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id); + } + + [Fact] + public async Task Run_BankAccount_StripeToPayPal_MakesCorrectInvocations_ReturnsMaskedBankAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345" + }, + Id = "cus_123", + Metadata = new Dictionary + { + [MetadataKeys.BraintreeCustomerId] = "braintree_customer_id" + } + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + const string token = "TOKEN"; + + var setupIntent = new SetupIntent + { + Id = "seti_123", + PaymentMethod = + new PaymentMethod + { + Type = "us_bank_account", + UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } + }, + NextAction = new SetupIntentNextAction + { + VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() + }, + Status = "requires_action" + }; + + _stripeAdapter.SetupIntentList(Arg.Is(options => + options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); + + var result = await _command.Run(organization, + new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = token }, new BillingAddress + { + Country = "US", + PostalCode = "12345" + }); + + Assert.True(result.IsT0); + var maskedPaymentMethod = result.AsT0; + Assert.True(maskedPaymentMethod.IsT0); + var maskedBankAccount = maskedPaymentMethod.AsT0; + Assert.Equal("Chase", maskedBankAccount.BankName); + Assert.Equal("9999", maskedBankAccount.Last4); + Assert.False(maskedBankAccount.Verified); + + await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id); + await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, Arg.Is(options => + options.Metadata[MetadataKeys.BraintreeCustomerId] == string.Empty && + options.Metadata[MetadataKeys.RetiredBraintreeCustomerId] == "braintree_customer_id")); + } + + [Fact] + public async Task Run_Card_MakesCorrectInvocations_ReturnsMaskedCard() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345" + }, + Id = "cus_123", + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + const string token = "TOKEN"; + + _stripeAdapter + .PaymentMethodAttachAsync(token, + Arg.Is(options => options.Customer == customer.Id)) + .Returns(new PaymentMethod + { + Type = "card", + Card = new PaymentMethodCard + { + Brand = "visa", + Last4 = "9999", + ExpMonth = 1, + ExpYear = 2028 + } + }); + + var result = await _command.Run(organization, + new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = token }, new BillingAddress + { + Country = "US", + PostalCode = "12345" + }); + + Assert.True(result.IsT0); + var maskedPaymentMethod = result.AsT0; + Assert.True(maskedPaymentMethod.IsT1); + var maskedCard = maskedPaymentMethod.AsT1; + Assert.Equal("visa", maskedCard.Brand); + Assert.Equal("9999", maskedCard.Last4); + Assert.Equal("01/2028", maskedCard.Expiration); + + await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token)); + } + + [Fact] + public async Task Run_Card_PropagateBillingAddress_MakesCorrectInvocations_ReturnsMaskedCard() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + Id = "cus_123", + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + const string token = "TOKEN"; + + _stripeAdapter + .PaymentMethodAttachAsync(token, + Arg.Is(options => options.Customer == customer.Id)) + .Returns(new PaymentMethod + { + Type = "card", + Card = new PaymentMethodCard + { + Brand = "visa", + Last4 = "9999", + ExpMonth = 1, + ExpYear = 2028 + } + }); + + var result = await _command.Run(organization, + new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = token }, new BillingAddress + { + Country = "US", + PostalCode = "12345" + }); + + Assert.True(result.IsT0); + var maskedPaymentMethod = result.AsT0; + Assert.True(maskedPaymentMethod.IsT1); + var maskedCard = maskedPaymentMethod.AsT1; + Assert.Equal("visa", maskedCard.Brand); + Assert.Equal("9999", maskedCard.Last4); + Assert.Equal("01/2028", maskedCard.Expiration); + + await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token)); + + await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + Arg.Is(options => options.Address.Country == "US" && options.Address.PostalCode == "12345")); + } + + [Fact] + public async Task Run_PayPal_ExistingBraintreeCustomer_MakesCorrectInvocations_ReturnsMaskedPayPalAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345" + }, + Id = "cus_123", + Metadata = new Dictionary + { + [MetadataKeys.BraintreeCustomerId] = "braintree_customer_id" + } + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + var customerGateway = Substitute.For(); + var braintreeCustomer = Substitute.For(); + braintreeCustomer.Id.Returns("braintree_customer_id"); + var existing = Substitute.For(); + existing.Email.Returns("user@gmail.com"); + existing.IsDefault.Returns(true); + existing.Token.Returns("EXISTING"); + braintreeCustomer.PaymentMethods.Returns([existing]); + customerGateway.FindAsync("braintree_customer_id").Returns(braintreeCustomer); + _braintreeGateway.Customer.Returns(customerGateway); + + var paymentMethodGateway = Substitute.For(); + var updated = Substitute.For(); + updated.Email.Returns("user@gmail.com"); + updated.Token.Returns("UPDATED"); + var updatedResult = Substitute.For>(); + updatedResult.Target.Returns(updated); + paymentMethodGateway.CreateAsync(Arg.Is(options => + options.CustomerId == braintreeCustomer.Id && options.PaymentMethodNonce == "TOKEN")) + .Returns(updatedResult); + _braintreeGateway.PaymentMethod.Returns(paymentMethodGateway); + + var result = await _command.Run(organization, + new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "TOKEN" }, + new BillingAddress { Country = "US", PostalCode = "12345" }); + + Assert.True(result.IsT0); + var maskedPaymentMethod = result.AsT0; + Assert.True(maskedPaymentMethod.IsT2); + var maskedPayPalAccount = maskedPaymentMethod.AsT2; + Assert.Equal("user@gmail.com", maskedPayPalAccount.Email); + + await customerGateway.Received(1).UpdateAsync(braintreeCustomer.Id, + Arg.Is(options => options.DefaultPaymentMethodToken == updated.Token)); + await paymentMethodGateway.Received(1).DeleteAsync(existing.Token); + } + + [Fact] + public async Task Run_PayPal_NewBraintreeCustomer_MakesCorrectInvocations_ReturnsMaskedPayPalAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345" + }, + Id = "cus_123", + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + _globalSettings.BaseServiceUri.Returns(new GlobalSettings.BaseServiceUriSettings(new GlobalSettings()) + { + CloudRegion = "US" + }); + + var customerGateway = Substitute.For(); + var braintreeCustomer = Substitute.For(); + braintreeCustomer.Id.Returns("braintree_customer_id"); + var payPalAccount = Substitute.For(); + payPalAccount.Email.Returns("user@gmail.com"); + payPalAccount.IsDefault.Returns(true); + payPalAccount.Token.Returns("NONCE"); + braintreeCustomer.PaymentMethods.Returns([payPalAccount]); + var createResult = Substitute.For>(); + createResult.Target.Returns(braintreeCustomer); + customerGateway.CreateAsync(Arg.Is(options => + options.Id.StartsWith(organization.BraintreeCustomerIdPrefix() + organization.Id.ToString("N").ToLower()) && + options.CustomFields[organization.BraintreeIdField()] == organization.Id.ToString() && + options.CustomFields[organization.BraintreeCloudRegionField()] == "US" && + options.Email == organization.BillingEmailAddress() && + options.PaymentMethodNonce == "TOKEN")).Returns(createResult); + _braintreeGateway.Customer.Returns(customerGateway); + + var result = await _command.Run(organization, + new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "TOKEN" }, + new BillingAddress { Country = "US", PostalCode = "12345" }); + + Assert.True(result.IsT0); + var maskedPaymentMethod = result.AsT0; + Assert.True(maskedPaymentMethod.IsT2); + var maskedPayPalAccount = maskedPaymentMethod.AsT2; + Assert.Equal("user@gmail.com", maskedPayPalAccount.Email); + + await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + Arg.Is(options => + options.Metadata[MetadataKeys.BraintreeCustomerId] == "braintree_customer_id")); + } +} diff --git a/test/Core.Test/Billing/Payment/Commands/VerifyBankAccountCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/VerifyBankAccountCommandTests.cs new file mode 100644 index 0000000000..4be5539cc8 --- /dev/null +++ b/test/Core.Test/Billing/Payment/Commands/VerifyBankAccountCommandTests.cs @@ -0,0 +1,81 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Payment.Commands; +using Bit.Core.Services; +using Bit.Core.Test.Billing.Extensions; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Payment.Commands; + +public class VerifyBankAccountCommandTests +{ + private readonly ISetupIntentCache _setupIntentCache = Substitute.For(); + private readonly IStripeAdapter _stripeAdapter = Substitute.For(); + private readonly VerifyBankAccountCommand _command; + + public VerifyBankAccountCommandTests() + { + _command = new VerifyBankAccountCommand( + Substitute.For>(), + _setupIntentCache, + _stripeAdapter); + } + + [Fact] + public async Task Run_MakesCorrectInvocations_ReturnsMaskedBankAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid(), + GatewayCustomerId = "cus_123" + }; + + const string setupIntentId = "seti_123"; + + _setupIntentCache.Get(organization.Id).Returns(setupIntentId); + + var setupIntent = new SetupIntent + { + Id = setupIntentId, + PaymentMethodId = "pm_123", + PaymentMethod = + new PaymentMethod + { + Id = "pm_123", + Type = "us_bank_account", + UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } + }, + NextAction = new SetupIntentNextAction + { + VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() + }, + Status = "requires_action" + }; + + _stripeAdapter.SetupIntentGet(setupIntentId, + Arg.Is(options => options.HasExpansions("payment_method"))).Returns(setupIntent); + + _stripeAdapter.PaymentMethodAttachAsync(setupIntent.PaymentMethodId, + Arg.Is(options => options.Customer == organization.GatewayCustomerId)) + .Returns(setupIntent.PaymentMethod); + + var result = await _command.Run(organization, "DESCRIPTOR_CODE"); + + Assert.True(result.IsT0); + var maskedPaymentMethod = result.AsT0; + Assert.True(maskedPaymentMethod.IsT0); + var maskedBankAccount = maskedPaymentMethod.AsT0; + Assert.Equal("Chase", maskedBankAccount.BankName); + Assert.Equal("9999", maskedBankAccount.Last4); + Assert.True(maskedBankAccount.Verified); + + await _stripeAdapter.Received(1).SetupIntentVerifyMicroDeposit(setupIntent.Id, + Arg.Is(options => options.DescriptorCode == "DESCRIPTOR_CODE")); + + await _stripeAdapter.Received(1).CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is( + options => options.InvoiceSettings.DefaultPaymentMethod == setupIntent.PaymentMethodId)); + } +} diff --git a/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs b/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs new file mode 100644 index 0000000000..345f2dfab8 --- /dev/null +++ b/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs @@ -0,0 +1,63 @@ +using System.Text.Json; +using Bit.Core.Billing.Payment.Models; +using Xunit; + +namespace Bit.Core.Test.Billing.Payment.Models; + +public class MaskedPaymentMethodTests +{ + [Fact] + public void Write_Read_BankAccount_Succeeds() + { + MaskedPaymentMethod input = new MaskedBankAccount + { + BankName = "Chase", + Last4 = "9999", + Verified = true + }; + + var json = JsonSerializer.Serialize(input); + + var output = JsonSerializer.Deserialize(json); + Assert.NotNull(output); + Assert.True(output.IsT0); + + Assert.Equivalent(input.AsT0, output.AsT0); + } + + [Fact] + public void Write_Read_Card_Succeeds() + { + MaskedPaymentMethod input = new MaskedCard + { + Brand = "visa", + Last4 = "9999", + Expiration = "01/2028" + }; + + var json = JsonSerializer.Serialize(input); + + var output = JsonSerializer.Deserialize(json); + Assert.NotNull(output); + Assert.True(output.IsT1); + + Assert.Equivalent(input.AsT1, output.AsT1); + } + + [Fact] + public void Write_Read_PayPal_Succeeds() + { + MaskedPaymentMethod input = new MaskedPayPalAccount + { + Email = "paypal-user@gmail.com" + }; + + var json = JsonSerializer.Serialize(input); + + var output = JsonSerializer.Deserialize(json); + Assert.NotNull(output); + Assert.True(output.IsT2); + + Assert.Equivalent(input.AsT2, output.AsT2); + } +} diff --git a/test/Core.Test/Billing/Payment/Queries/GetBillingAddressQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetBillingAddressQueryTests.cs new file mode 100644 index 0000000000..048c143a0e --- /dev/null +++ b/test/Core.Test/Billing/Payment/Queries/GetBillingAddressQueryTests.cs @@ -0,0 +1,204 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Test.Billing.Extensions; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Payment.Queries; + +public class GetBillingAddressQueryTests +{ + private readonly ISubscriberService _subscriberService = Substitute.For(); + private readonly GetBillingAddressQuery _query; + + public GetBillingAddressQueryTests() + { + _query = new GetBillingAddressQuery(_subscriberService); + } + + [Fact] + public async Task Run_ForUserWithNoAddress_ReturnsNull() + { + var user = new User(); + + var customer = new Customer(); + + _subscriberService.GetCustomer(user, Arg.Is( + options => options.Expand == null)).Returns(customer); + + var billingAddress = await _query.Run(user); + + Assert.Null(billingAddress); + } + + [Fact] + public async Task Run_ForUserWithAddress_ReturnsBillingAddress() + { + var user = new User(); + + var address = GetAddress(); + + var customer = new Customer + { + Address = address + }; + + _subscriberService.GetCustomer(user, Arg.Is( + options => options.Expand == null)).Returns(customer); + + var billingAddress = await _query.Run(user); + + AssertEquality(address, billingAddress); + } + + [Fact] + public async Task Run_ForPersonalOrganizationWithNoAddress_ReturnsNull() + { + var organization = new Organization + { + PlanType = PlanType.FamiliesAnnually + }; + + var customer = new Customer(); + + _subscriberService.GetCustomer(organization, Arg.Is( + options => options.Expand == null)).Returns(customer); + + var billingAddress = await _query.Run(organization); + + Assert.Null(billingAddress); + } + + [Fact] + public async Task Run_ForPersonalOrganizationWithAddress_ReturnsBillingAddress() + { + var organization = new Organization + { + PlanType = PlanType.FamiliesAnnually + }; + + var address = GetAddress(); + + var customer = new Customer + { + Address = address + }; + + _subscriberService.GetCustomer(organization, Arg.Is( + options => options.Expand == null)).Returns(customer); + + var billingAddress = await _query.Run(organization); + + AssertEquality(customer.Address, billingAddress); + } + + [Fact] + public async Task Run_ForBusinessOrganizationWithNoAddress_ReturnsNull() + { + var organization = new Organization + { + PlanType = PlanType.EnterpriseAnnually + }; + + var customer = new Customer(); + + _subscriberService.GetCustomer(organization, Arg.Is( + options => options.HasExpansions("tax_ids"))).Returns(customer); + + var billingAddress = await _query.Run(organization); + + Assert.Null(billingAddress); + } + + [Fact] + public async Task Run_ForBusinessOrganizationWithAddressAndTaxId_ReturnsBillingAddressWithTaxId() + { + var organization = new Organization + { + PlanType = PlanType.EnterpriseAnnually + }; + + var address = GetAddress(); + + var taxId = GetTaxId(); + + var customer = new Customer + { + Address = address, + TaxIds = new StripeList + { + Data = [taxId] + } + }; + + _subscriberService.GetCustomer(organization, Arg.Is( + options => options.HasExpansions("tax_ids"))).Returns(customer); + + var billingAddress = await _query.Run(organization); + + AssertEquality(address, taxId, billingAddress); + } + + [Fact] + public async Task Run_ForProviderWithAddressAndTaxId_ReturnsBillingAddressWithTaxId() + { + var provider = new Provider(); + + var address = GetAddress(); + + var taxId = GetTaxId(); + + var customer = new Customer + { + Address = address, + TaxIds = new StripeList + { + Data = [taxId] + } + }; + + _subscriberService.GetCustomer(provider, Arg.Is( + options => options.HasExpansions("tax_ids"))).Returns(customer); + + var billingAddress = await _query.Run(provider); + + AssertEquality(address, taxId, billingAddress); + } + + private static void AssertEquality(Address address, BillingAddress? billingAddress) + { + Assert.NotNull(billingAddress); + Assert.Equal(address.Country, billingAddress.Country); + Assert.Equal(address.PostalCode, billingAddress.PostalCode); + Assert.Equal(address.Line1, billingAddress.Line1); + Assert.Equal(address.Line2, billingAddress.Line2); + Assert.Equal(address.City, billingAddress.City); + Assert.Equal(address.State, billingAddress.State); + } + + private static void AssertEquality(Address address, TaxId taxId, BillingAddress? billingAddress) + { + AssertEquality(address, billingAddress); + Assert.NotNull(billingAddress!.TaxId); + Assert.Equal(taxId.Type, billingAddress.TaxId!.Code); + Assert.Equal(taxId.Value, billingAddress.TaxId!.Value); + } + + private static Address GetAddress() => new() + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Suite 100", + City = "New York", + State = "NY" + }; + + private static TaxId GetTaxId() => new() { Type = "us_ein", Value = "123456789" }; +} diff --git a/test/Core.Test/Billing/Payment/Queries/GetCreditQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetCreditQueryTests.cs new file mode 100644 index 0000000000..55f5e85009 --- /dev/null +++ b/test/Core.Test/Billing/Payment/Queries/GetCreditQueryTests.cs @@ -0,0 +1,41 @@ +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Payment.Queries; + +public class GetCreditQueryTests +{ + private readonly ISubscriberService _subscriberService = Substitute.For(); + private readonly GetCreditQuery _query; + + public GetCreditQueryTests() + { + _query = new GetCreditQuery(_subscriberService); + } + + [Fact] + public async Task Run_NoCustomer_ReturnsNull() + { + _subscriberService.GetCustomer(Arg.Any()).ReturnsNull(); + + var credit = await _query.Run(Substitute.For()); + + Assert.Null(credit); + } + + [Fact] + public async Task Run_ReturnsCredit() + { + _subscriberService.GetCustomer(Arg.Any()).Returns(new Customer { Balance = -1000 }); + + var credit = await _query.Run(Substitute.For()); + + Assert.NotNull(credit); + Assert.Equal(10M, credit); + } +} diff --git a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs new file mode 100644 index 0000000000..4d82b4b5c9 --- /dev/null +++ b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs @@ -0,0 +1,327 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Services; +using Bit.Core.Services; +using Bit.Core.Test.Billing.Extensions; +using Braintree; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Stripe; +using Xunit; +using Customer = Stripe.Customer; +using PaymentMethod = Stripe.PaymentMethod; + +namespace Bit.Core.Test.Billing.Payment.Queries; + +using static StripeConstants; + +public class GetPaymentMethodQueryTests +{ + private readonly IBraintreeGateway _braintreeGateway = Substitute.For(); + private readonly ISetupIntentCache _setupIntentCache = Substitute.For(); + private readonly IStripeAdapter _stripeAdapter = Substitute.For(); + private readonly ISubscriberService _subscriberService = Substitute.For(); + private readonly GetPaymentMethodQuery _query; + + public GetPaymentMethodQueryTests() + { + _query = new GetPaymentMethodQuery( + _braintreeGateway, + Substitute.For>(), + _setupIntentCache, + _stripeAdapter, + _subscriberService); + } + + [Fact] + public async Task Run_NoPaymentMethod_ReturnsNull() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.Null(maskedPaymentMethod); + } + + [Fact] + public async Task Run_BankAccount_FromPaymentMethod_ReturnsMaskedBankAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings + { + DefaultPaymentMethod = new PaymentMethod + { + Type = "us_bank_account", + UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } + } + }, + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.NotNull(maskedPaymentMethod); + Assert.True(maskedPaymentMethod.IsT0); + var maskedBankAccount = maskedPaymentMethod.AsT0; + Assert.Equal("Chase", maskedBankAccount.BankName); + Assert.Equal("9999", maskedBankAccount.Last4); + Assert.True(maskedBankAccount.Verified); + } + + [Fact] + public async Task Run_BankAccount_FromSource_ReturnsMaskedBankAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + DefaultSource = new BankAccount + { + BankName = "Chase", + Last4 = "9999", + Status = "verified" + }, + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.NotNull(maskedPaymentMethod); + Assert.True(maskedPaymentMethod.IsT0); + var maskedBankAccount = maskedPaymentMethod.AsT0; + Assert.Equal("Chase", maskedBankAccount.BankName); + Assert.Equal("9999", maskedBankAccount.Last4); + Assert.True(maskedBankAccount.Verified); + } + + [Fact] + public async Task Run_BankAccount_FromSetupIntent_ReturnsMaskedBankAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); + + _setupIntentCache.Get(organization.Id).Returns("seti_123"); + + _stripeAdapter + .SetupIntentGet("seti_123", + Arg.Is(options => options.HasExpansions("payment_method"))).Returns( + new SetupIntent + { + PaymentMethod = new PaymentMethod + { + Type = "us_bank_account", + UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } + }, + NextAction = new SetupIntentNextAction + { + VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() + }, + Status = "requires_action" + }); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.NotNull(maskedPaymentMethod); + Assert.True(maskedPaymentMethod.IsT0); + var maskedBankAccount = maskedPaymentMethod.AsT0; + Assert.Equal("Chase", maskedBankAccount.BankName); + Assert.Equal("9999", maskedBankAccount.Last4); + Assert.False(maskedBankAccount.Verified); + } + + [Fact] + public async Task Run_Card_FromPaymentMethod_ReturnsMaskedCard() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings + { + DefaultPaymentMethod = new PaymentMethod + { + Type = "card", + Card = new PaymentMethodCard + { + Brand = "visa", + Last4 = "9999", + ExpMonth = 1, + ExpYear = 2028 + } + } + }, + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.NotNull(maskedPaymentMethod); + Assert.True(maskedPaymentMethod.IsT1); + var maskedCard = maskedPaymentMethod.AsT1; + Assert.Equal("visa", maskedCard.Brand); + Assert.Equal("9999", maskedCard.Last4); + Assert.Equal("01/2028", maskedCard.Expiration); + } + + [Fact] + public async Task Run_Card_FromSource_ReturnsMaskedCard() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + DefaultSource = new Card + { + Brand = "visa", + Last4 = "9999", + ExpMonth = 1, + ExpYear = 2028 + }, + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.NotNull(maskedPaymentMethod); + Assert.True(maskedPaymentMethod.IsT1); + var maskedCard = maskedPaymentMethod.AsT1; + Assert.Equal("visa", maskedCard.Brand); + Assert.Equal("9999", maskedCard.Last4); + Assert.Equal("01/2028", maskedCard.Expiration); + } + + [Fact] + public async Task Run_Card_FromSourceCard_ReturnsMaskedCard() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + DefaultSource = new Source + { + Card = new SourceCard + { + Brand = "Visa", + Last4 = "9999", + ExpMonth = 1, + ExpYear = 2028 + } + }, + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.NotNull(maskedPaymentMethod); + Assert.True(maskedPaymentMethod.IsT1); + var maskedCard = maskedPaymentMethod.AsT1; + Assert.Equal("visa", maskedCard.Brand); + Assert.Equal("9999", maskedCard.Last4); + Assert.Equal("01/2028", maskedCard.Expiration); + } + + [Fact] + public async Task Run_PayPalAccount_ReturnsMaskedPayPalAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + Metadata = new Dictionary + { + [MetadataKeys.BraintreeCustomerId] = "braintree_customer_id" + } + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); + + var customerGateway = Substitute.For(); + var braintreeCustomer = Substitute.For(); + var payPalAccount = Substitute.For(); + payPalAccount.Email.Returns("user@gmail.com"); + payPalAccount.IsDefault.Returns(true); + braintreeCustomer.PaymentMethods.Returns([payPalAccount]); + customerGateway.FindAsync("braintree_customer_id").Returns(braintreeCustomer); + _braintreeGateway.Customer.Returns(customerGateway); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.NotNull(maskedPaymentMethod); + Assert.True(maskedPaymentMethod.IsT2); + var maskedPayPalAccount = maskedPaymentMethod.AsT2; + Assert.Equal("user@gmail.com", maskedPayPalAccount.Email); + } +} diff --git a/test/Core.Test/Billing/Tax/Commands/PreviewTaxAmountCommandTests.cs b/test/Core.Test/Billing/Tax/Commands/PreviewTaxAmountCommandTests.cs index c35dc275e6..ee5625d522 100644 --- a/test/Core.Test/Billing/Tax/Commands/PreviewTaxAmountCommandTests.cs +++ b/test/Core.Test/Billing/Tax/Commands/PreviewTaxAmountCommandTests.cs @@ -1,6 +1,5 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Tax.Commands; using Bit.Core.Billing.Tax.Services; @@ -8,7 +7,6 @@ using Bit.Core.Services; using Bit.Core.Utilities; using Microsoft.Extensions.Logging; using NSubstitute; -using NSubstitute.ExceptionExtensions; using Stripe; using Xunit; using static Bit.Core.Billing.Tax.Commands.OrganizationTrialParameters; @@ -273,74 +271,6 @@ public class PreviewTaxAmountCommandTests // Assert Assert.True(result.IsT1); var badRequest = result.AsT1; - Assert.Equal(BillingErrorTranslationKeys.UnknownTaxIdType, badRequest.TranslationKey); - } - - [Fact] - public async Task Run_CustomerTaxLocationInvalid_BadRequest() - { - // Arrange - var parameters = new OrganizationTrialParameters - { - PlanType = PlanType.EnterpriseAnnually, - ProductType = ProductType.PasswordManager, - TaxInformation = new TaxInformationDTO - { - Country = "US", - PostalCode = "12345" - } - }; - - var plan = StaticStore.GetPlan(parameters.PlanType); - - _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); - - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) - .Throws(new StripeException - { - StripeError = new StripeError { Code = StripeConstants.ErrorCodes.CustomerTaxLocationInvalid } - }); - - // Act - var result = await _command.Run(parameters); - - // Assert - Assert.True(result.IsT1); - var badRequest = result.AsT1; - Assert.Equal(BillingErrorTranslationKeys.CustomerTaxLocationInvalid, badRequest.TranslationKey); - } - - [Fact] - public async Task Run_TaxIdInvalid_BadRequest() - { - // Arrange - var parameters = new OrganizationTrialParameters - { - PlanType = PlanType.EnterpriseAnnually, - ProductType = ProductType.PasswordManager, - TaxInformation = new TaxInformationDTO - { - Country = "US", - PostalCode = "12345" - } - }; - - var plan = StaticStore.GetPlan(parameters.PlanType); - - _pricingClient.GetPlanOrThrow(parameters.PlanType).Returns(plan); - - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()) - .Throws(new StripeException - { - StripeError = new StripeError { Code = StripeConstants.ErrorCodes.TaxIdInvalid } - }); - - // Act - var result = await _command.Run(parameters); - - // Assert - Assert.True(result.IsT1); - var badRequest = result.AsT1; - Assert.Equal(BillingErrorTranslationKeys.TaxIdInvalid, badRequest.TranslationKey); + Assert.Equal("We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support for assistance.", badRequest.Response); } }