From d914ab8a988b3aee0113f5afbb70ac36675b7d72 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Mon, 14 Jul 2025 12:39:49 -0500 Subject: [PATCH] [PM-23687] Support free organizations on Payment Details page (#6084) * Resolve JSON serialization bug in OneOf converters and organize pricing models * Support free organizations for payment method and billing address flows * Run dotnet format --- ...illingCommand.cs => BaseBillingCommand.cs} | 31 +++- src/Core/Billing/Constants/StripeConstants.cs | 1 + .../CreateBitPayInvoiceForCreditCommand.cs | 7 +- .../Commands/UpdateBillingAddressCommand.cs | 24 +++- .../Commands/UpdatePaymentMethodCommand.cs | 17 ++- .../Commands/VerifyBankAccountCommand.cs | 11 +- .../Billing/Payment/Models/BillingAddress.cs | 3 +- .../Payment/Models/MaskedPaymentMethod.cs | 40 +++--- .../Payment/Models/TokenizedPaymentMethod.cs | 3 +- .../Payment/Queries/GetBillingAddressQuery.cs | 3 +- .../Billing/Payment/Queries/GetCreditQuery.cs | 3 +- .../Payment/Queries/GetPaymentMethodQuery.cs | 8 +- .../JSON/FreeOrScalableDTOJsonConverter.cs | 35 ----- .../JSON/PurchasableDTOJsonConverter.cs | 40 ------ .../Pricing/JSON/TypeReadingJsonConverter.cs | 36 ----- src/Core/Billing/Pricing/Models/Feature.cs | 7 + src/Core/Billing/Pricing/Models/FeatureDTO.cs | 9 -- src/Core/Billing/Pricing/Models/Plan.cs | 25 ++++ src/Core/Billing/Pricing/Models/PlanDTO.cs | 27 ---- .../Billing/Pricing/Models/Purchasable.cs | 135 ++++++++++++++++++ .../Billing/Pricing/Models/PurchasableDTO.cs | 73 ---------- src/Core/Billing/Pricing/PlanAdapter.cs | 40 +++--- src/Core/Billing/Pricing/PricingClient.cs | 5 +- .../Billing/Services/ISubscriberService.cs | 3 + .../Implementations/SubscriberService.cs | 113 ++++++++++++++- .../Tax/Commands/PreviewTaxAmountCommand.cs | 8 +- .../UpdateBillingAddressCommandTests.cs | 66 ++++++++- .../UpdatePaymentMethodCommandTests.cs | 79 +++++++++- .../Models/MaskedPaymentMethodTests.cs | 21 +++ .../Queries/GetPaymentMethodQueryTests.cs | 18 +++ 30 files changed, 575 insertions(+), 316 deletions(-) rename src/Core/Billing/Commands/{BillingCommand.cs => BaseBillingCommand.cs} (60%) delete mode 100644 src/Core/Billing/Pricing/JSON/FreeOrScalableDTOJsonConverter.cs delete mode 100644 src/Core/Billing/Pricing/JSON/PurchasableDTOJsonConverter.cs delete mode 100644 src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs create mode 100644 src/Core/Billing/Pricing/Models/Feature.cs delete mode 100644 src/Core/Billing/Pricing/Models/FeatureDTO.cs create mode 100644 src/Core/Billing/Pricing/Models/Plan.cs delete mode 100644 src/Core/Billing/Pricing/Models/PlanDTO.cs create mode 100644 src/Core/Billing/Pricing/Models/Purchasable.cs delete mode 100644 src/Core/Billing/Pricing/Models/PurchasableDTO.cs diff --git a/src/Core/Billing/Commands/BillingCommand.cs b/src/Core/Billing/Commands/BaseBillingCommand.cs similarity index 60% rename from src/Core/Billing/Commands/BillingCommand.cs rename to src/Core/Billing/Commands/BaseBillingCommand.cs index e6c6375b62..b3e938548d 100644 --- a/src/Core/Billing/Commands/BillingCommand.cs +++ b/src/Core/Billing/Commands/BaseBillingCommand.cs @@ -1,4 +1,5 @@ using Bit.Core.Billing.Constants; +using Bit.Core.Exceptions; using Microsoft.Extensions.Logging; using Stripe; @@ -6,11 +7,17 @@ namespace Bit.Core.Billing.Commands; using static StripeConstants; -public abstract class BillingCommand( +public abstract class BaseBillingCommand( ILogger logger) { protected string CommandName => GetType().Name; + /// + /// Override this property to set a client-facing conflict response in the case a is thrown + /// during the command's execution. + /// + protected virtual Conflict? DefaultConflict => null; + /// /// Executes the provided function within a predefined execution context, handling any exceptions that occur during the process. /// @@ -29,23 +36,35 @@ public abstract class BillingCommand( return stripeException.StripeError.Code switch { ErrorCodes.CustomerTaxLocationInvalid => - new BadRequest("Your location wasn't recognized. Please ensure your country and postal code are valid and try again."), + new BadRequest( + "Your location wasn't recognized. Please ensure your country and postal code are valid and try again."), ErrorCodes.PaymentMethodMicroDepositVerificationAttemptsExceeded => - new BadRequest("You have exceeded the number of allowed verification attempts. Please contact support for assistance."), + new BadRequest( + "You have exceeded the number of allowed verification attempts. Please contact support for assistance."), ErrorCodes.PaymentMethodMicroDepositVerificationDescriptorCodeMismatch => - new BadRequest("The verification code you provided does not match the one sent to your bank account. Please try again."), + new BadRequest( + "The verification code you provided does not match the one sent to your bank account. Please try again."), ErrorCodes.PaymentMethodMicroDepositVerificationTimeout => - new BadRequest("Your bank account was not verified within the required time period. Please contact support for assistance."), + new BadRequest( + "Your bank account was not verified within the required time period. Please contact support for assistance."), ErrorCodes.TaxIdInvalid => - new BadRequest("The tax ID number you provided was invalid. Please try again or contact support for assistance."), + new BadRequest( + "The tax ID number you provided was invalid. Please try again or contact support for assistance."), _ => new Unhandled(stripeException) }; } + catch (ConflictException conflictException) + { + logger.LogError("{Command}: {Message}", CommandName, conflictException.Message); + return DefaultConflict != null ? + DefaultConflict : + new Unhandled(conflictException); + } catch (StripeException stripeException) { logger.LogError(stripeException, diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index 3aaa519d66..6ecfb4d28b 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -60,6 +60,7 @@ public static class StripeConstants public const string InvoiceApproved = "invoice_approved"; public const string OrganizationId = "organizationId"; public const string ProviderId = "providerId"; + public const string Region = "region"; public const string RetiredBraintreeCustomerId = "btCustomerId_old"; public const string UserId = "userId"; } diff --git a/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs index f61fa9d0f9..a86f0e3ada 100644 --- a/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs +++ b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Commands; using Bit.Core.Billing.Payment.Clients; @@ -21,8 +20,10 @@ public interface ICreateBitPayInvoiceForCreditCommand public class CreateBitPayInvoiceForCreditCommand( IBitPayClient bitPayClient, GlobalSettings globalSettings, - ILogger logger) : BillingCommand(logger), ICreateBitPayInvoiceForCreditCommand + ILogger logger) : BaseBillingCommand(logger), ICreateBitPayInvoiceForCreditCommand { + protected override Conflict DefaultConflict => new("We had a problem applying your account credit. Please contact support for assistance."); + public Task> Run( ISubscriber subscriber, decimal amount, diff --git a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs index adc534bd7d..fdf519523a 100644 --- a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs @@ -1,8 +1,8 @@ -#nullable enable -using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Services; using Microsoft.Extensions.Logging; @@ -19,14 +19,26 @@ public interface IUpdateBillingAddressCommand public class UpdateBillingAddressCommand( ILogger logger, - IStripeAdapter stripeAdapter) : BillingCommand(logger), IUpdateBillingAddressCommand + ISubscriberService subscriberService, + IStripeAdapter stripeAdapter) : BaseBillingCommand(logger), IUpdateBillingAddressCommand { + protected override Conflict DefaultConflict => + new("We had a problem updating your billing address. Please contact support for assistance."); + public Task> Run( ISubscriber subscriber, - BillingAddress billingAddress) => HandleAsync(() => subscriber.GetProductUsageType() switch + BillingAddress billingAddress) => HandleAsync(async () => { - ProductUsageType.Personal => UpdatePersonalBillingAddressAsync(subscriber, billingAddress), - ProductUsageType.Business => UpdateBusinessBillingAddressAsync(subscriber, billingAddress) + if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) + { + await subscriberService.CreateStripeCustomer(subscriber); + } + + return subscriber.GetProductUsageType() switch + { + ProductUsageType.Personal => await UpdatePersonalBillingAddressAsync(subscriber, billingAddress), + ProductUsageType.Business => await UpdateBusinessBillingAddressAsync(subscriber, billingAddress) + }; }); private async Task> UpdatePersonalBillingAddressAsync( diff --git a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs index cda685d520..81206b8032 100644 --- a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Caches; using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Models; @@ -29,16 +28,22 @@ public class UpdatePaymentMethodCommand( ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ISubscriberService subscriberService) : BillingCommand(logger), IUpdatePaymentMethodCommand + ISubscriberService subscriberService) : BaseBillingCommand(logger), IUpdatePaymentMethodCommand { private readonly ILogger _logger = logger; - private static readonly Conflict _conflict = new("We had a problem updating your payment method. Please contact support for assistance."); + protected override Conflict DefaultConflict + => new("We had a problem updating your payment method. Please contact support for assistance."); public Task> Run( ISubscriber subscriber, TokenizedPaymentMethod paymentMethod, BillingAddress? billingAddress) => HandleAsync(async () => { + if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) + { + await subscriberService.CreateStripeCustomer(subscriber); + } + var customer = await subscriberService.GetCustomer(subscriber); var result = paymentMethod.Type switch @@ -80,10 +85,10 @@ public class UpdatePaymentMethodCommand( { case 0: _logger.LogError("{Command}: Could not find setup intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id); - return _conflict; + return DefaultConflict; case > 1: _logger.LogError("{Command}: Found more than one set up intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id); - return _conflict; + return DefaultConflict; } var setupIntent = setupIntents.First(); diff --git a/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs b/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs index 1e9492b876..4f3e38707c 100644 --- a/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs +++ b/src/Core/Billing/Payment/Commands/VerifyBankAccountCommand.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Caches; using Bit.Core.Billing.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Entities; @@ -19,12 +18,12 @@ public interface IVerifyBankAccountCommand public class VerifyBankAccountCommand( ILogger logger, ISetupIntentCache setupIntentCache, - IStripeAdapter stripeAdapter) : BillingCommand(logger), IVerifyBankAccountCommand + IStripeAdapter stripeAdapter) : BaseBillingCommand(logger), IVerifyBankAccountCommand { private readonly ILogger _logger = logger; - private static readonly Conflict _conflict = - new("We had a problem verifying your bank account. Please contact support for assistance."); + protected override Conflict DefaultConflict + => new("We had a problem verifying your bank account. Please contact support for assistance."); public Task> Run( ISubscriber subscriber, @@ -37,7 +36,7 @@ public class VerifyBankAccountCommand( _logger.LogError( "{Command}: Could not find setup intent to verify subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id); - return _conflict; + return DefaultConflict; } await stripeAdapter.SetupIntentVerifyMicroDeposit(setupIntentId, diff --git a/src/Core/Billing/Payment/Models/BillingAddress.cs b/src/Core/Billing/Payment/Models/BillingAddress.cs index 5c2c43231c..39dd1f4121 100644 --- a/src/Core/Billing/Payment/Models/BillingAddress.cs +++ b/src/Core/Billing/Payment/Models/BillingAddress.cs @@ -1,5 +1,4 @@ -#nullable enable -using Stripe; +using Stripe; namespace Bit.Core.Billing.Payment.Models; diff --git a/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs b/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs index c98fddc785..d23ca75025 100644 --- a/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs +++ b/src/Core/Billing/Payment/Models/MaskedPaymentMethod.cs @@ -1,7 +1,5 @@ -#nullable enable -using System.Text.Json; +using System.Text.Json; using System.Text.Json.Serialization; -using Bit.Core.Billing.Pricing.JSON; using Braintree; using OneOf; using Stripe; @@ -83,32 +81,28 @@ public class MaskedPaymentMethod(OneOf new MaskedPayPalAccount { Email = payPalAccount.Email }; } -public class MaskedPaymentMethodJsonConverter : TypeReadingJsonConverter +public class MaskedPaymentMethodJsonConverter : JsonConverter { - protected override string TypePropertyName => nameof(MaskedBankAccount.Type).ToLower(); + private const string _typePropertyName = nameof(MaskedBankAccount.Type); - public override MaskedPaymentMethod? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override MaskedPaymentMethod Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { - var type = ReadType(reader); + var element = JsonElement.ParseValue(ref reader); + + if (!element.TryGetProperty(options.PropertyNamingPolicy?.ConvertName(_typePropertyName) ?? _typePropertyName, out var typeProperty)) + { + throw new JsonException( + $"Failed to deserialize {nameof(MaskedPaymentMethod)}: missing '{_typePropertyName}' property"); + } + + var type = typeProperty.GetString(); return type switch { - "bankAccount" => JsonSerializer.Deserialize(ref reader, options) switch - { - null => null, - var bankAccount => new MaskedPaymentMethod(bankAccount) - }, - "card" => JsonSerializer.Deserialize(ref reader, options) switch - { - null => null, - var card => new MaskedPaymentMethod(card) - }, - "payPal" => JsonSerializer.Deserialize(ref reader, options) switch - { - null => null, - var payPal => new MaskedPaymentMethod(payPal) - }, - _ => Skip(ref reader) + "bankAccount" => element.Deserialize(options)!, + "card" => element.Deserialize(options)!, + "payPal" => element.Deserialize(options)!, + _ => throw new JsonException($"Failed to deserialize {nameof(MaskedPaymentMethod)}: invalid '{_typePropertyName}' value - '{type}'") }; } diff --git a/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs b/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs index edbf1bb121..9af7c9888a 100644 --- a/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs +++ b/src/Core/Billing/Payment/Models/TokenizedPaymentMethod.cs @@ -1,5 +1,4 @@ -#nullable enable -namespace Bit.Core.Billing.Payment.Models; +namespace Bit.Core.Billing.Payment.Models; public record TokenizedPaymentMethod { diff --git a/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs b/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs index 84d4d4f377..e49c2cc993 100644 --- a/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs +++ b/src/Core/Billing/Payment/Queries/GetBillingAddressQuery.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; diff --git a/src/Core/Billing/Payment/Queries/GetCreditQuery.cs b/src/Core/Billing/Payment/Queries/GetCreditQuery.cs index 79c9a13aba..81d560269b 100644 --- a/src/Core/Billing/Payment/Queries/GetCreditQuery.cs +++ b/src/Core/Billing/Payment/Queries/GetCreditQuery.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services; using Bit.Core.Entities; namespace Bit.Core.Billing.Payment.Queries; diff --git a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs index eb42a8c78a..ce8f031a5d 100644 --- a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs +++ b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; @@ -29,6 +28,11 @@ public class GetPaymentMethodQuery( var customer = await subscriberService.GetCustomer(subscriber, new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] }); + if (customer == null) + { + return null; + } + if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId)) { var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); diff --git a/src/Core/Billing/Pricing/JSON/FreeOrScalableDTOJsonConverter.cs b/src/Core/Billing/Pricing/JSON/FreeOrScalableDTOJsonConverter.cs deleted file mode 100644 index 37a8a4234d..0000000000 --- a/src/Core/Billing/Pricing/JSON/FreeOrScalableDTOJsonConverter.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System.Text.Json; -using Bit.Core.Billing.Pricing.Models; - -namespace Bit.Core.Billing.Pricing.JSON; - -#nullable enable - -public class FreeOrScalableDTOJsonConverter : TypeReadingJsonConverter -{ - public override FreeOrScalableDTO? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - var type = ReadType(reader); - - return type switch - { - "free" => JsonSerializer.Deserialize(ref reader, options) switch - { - null => null, - var free => new FreeOrScalableDTO(free) - }, - "scalable" => JsonSerializer.Deserialize(ref reader, options) switch - { - null => null, - var scalable => new FreeOrScalableDTO(scalable) - }, - _ => null - }; - } - - public override void Write(Utf8JsonWriter writer, FreeOrScalableDTO value, JsonSerializerOptions options) - => value.Switch( - free => JsonSerializer.Serialize(writer, free, options), - scalable => JsonSerializer.Serialize(writer, scalable, options) - ); -} diff --git a/src/Core/Billing/Pricing/JSON/PurchasableDTOJsonConverter.cs b/src/Core/Billing/Pricing/JSON/PurchasableDTOJsonConverter.cs deleted file mode 100644 index f7ae9dc472..0000000000 --- a/src/Core/Billing/Pricing/JSON/PurchasableDTOJsonConverter.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System.Text.Json; -using Bit.Core.Billing.Pricing.Models; - -namespace Bit.Core.Billing.Pricing.JSON; - -#nullable enable -internal class PurchasableDTOJsonConverter : TypeReadingJsonConverter -{ - public override PurchasableDTO? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - var type = ReadType(reader); - - return type switch - { - "free" => JsonSerializer.Deserialize(ref reader, options) switch - { - null => null, - var free => new PurchasableDTO(free) - }, - "packaged" => JsonSerializer.Deserialize(ref reader, options) switch - { - null => null, - var packaged => new PurchasableDTO(packaged) - }, - "scalable" => JsonSerializer.Deserialize(ref reader, options) switch - { - null => null, - var scalable => new PurchasableDTO(scalable) - }, - _ => null - }; - } - - public override void Write(Utf8JsonWriter writer, PurchasableDTO value, JsonSerializerOptions options) - => value.Switch( - free => JsonSerializer.Serialize(writer, free, options), - packaged => JsonSerializer.Serialize(writer, packaged, options), - scalable => JsonSerializer.Serialize(writer, scalable, options) - ); -} diff --git a/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs b/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs deleted file mode 100644 index 05beccdb60..0000000000 --- a/src/Core/Billing/Pricing/JSON/TypeReadingJsonConverter.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System.Text.Json; -using System.Text.Json.Serialization; -using Bit.Core.Billing.Pricing.Models; - -namespace Bit.Core.Billing.Pricing.JSON; - -#nullable enable - -public abstract class TypeReadingJsonConverter : JsonConverter where T : class -{ - protected virtual string TypePropertyName => nameof(ScalableDTO.Type).ToLower(); - - protected string? ReadType(Utf8JsonReader reader) - { - while (reader.Read()) - { - if (reader.CurrentDepth != 1 || - reader.TokenType != JsonTokenType.PropertyName || - reader.GetString()?.ToLower() != TypePropertyName) - { - continue; - } - - reader.Read(); - return reader.GetString(); - } - - return null; - } - - protected T? Skip(ref Utf8JsonReader reader) - { - reader.Skip(); - return null; - } -} diff --git a/src/Core/Billing/Pricing/Models/Feature.cs b/src/Core/Billing/Pricing/Models/Feature.cs new file mode 100644 index 0000000000..ea9da5217d --- /dev/null +++ b/src/Core/Billing/Pricing/Models/Feature.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Billing.Pricing.Models; + +public class Feature +{ + public required string Name { get; set; } + public required string LookupKey { get; set; } +} diff --git a/src/Core/Billing/Pricing/Models/FeatureDTO.cs b/src/Core/Billing/Pricing/Models/FeatureDTO.cs deleted file mode 100644 index a96ac019e3..0000000000 --- a/src/Core/Billing/Pricing/Models/FeatureDTO.cs +++ /dev/null @@ -1,9 +0,0 @@ -namespace Bit.Core.Billing.Pricing.Models; - -#nullable enable - -public class FeatureDTO -{ - public string Name { get; set; } = null!; - public string LookupKey { get; set; } = null!; -} diff --git a/src/Core/Billing/Pricing/Models/Plan.cs b/src/Core/Billing/Pricing/Models/Plan.cs new file mode 100644 index 0000000000..5b4296474b --- /dev/null +++ b/src/Core/Billing/Pricing/Models/Plan.cs @@ -0,0 +1,25 @@ +namespace Bit.Core.Billing.Pricing.Models; + +public class Plan +{ + public required string LookupKey { get; set; } + public required string Name { get; set; } + public required string Tier { get; set; } + public string? Cadence { get; set; } + public int? LegacyYear { get; set; } + public bool Available { get; set; } + public required Feature[] Features { get; set; } + public required Purchasable Seats { get; set; } + public Scalable? ManagedSeats { get; set; } + public Scalable? Storage { get; set; } + public SecretsManagerPurchasables? SecretsManager { get; set; } + public int? TrialPeriodDays { get; set; } + public required string[] CanUpgradeTo { get; set; } + public required Dictionary AdditionalData { get; set; } +} + +public class SecretsManagerPurchasables +{ + public required FreeOrScalable Seats { get; set; } + public required FreeOrScalable ServiceAccounts { get; set; } +} diff --git a/src/Core/Billing/Pricing/Models/PlanDTO.cs b/src/Core/Billing/Pricing/Models/PlanDTO.cs deleted file mode 100644 index 4ae82b3efe..0000000000 --- a/src/Core/Billing/Pricing/Models/PlanDTO.cs +++ /dev/null @@ -1,27 +0,0 @@ -namespace Bit.Core.Billing.Pricing.Models; - -#nullable enable - -public class PlanDTO -{ - public string LookupKey { get; set; } = null!; - public string Name { get; set; } = null!; - public string Tier { get; set; } = null!; - public string? Cadence { get; set; } - public int? LegacyYear { get; set; } - public bool Available { get; set; } - public FeatureDTO[] Features { get; set; } = null!; - public PurchasableDTO Seats { get; set; } = null!; - public ScalableDTO? ManagedSeats { get; set; } - public ScalableDTO? Storage { get; set; } - public SecretsManagerPurchasablesDTO? SecretsManager { get; set; } - public int? TrialPeriodDays { get; set; } - public string[] CanUpgradeTo { get; set; } = null!; - public Dictionary AdditionalData { get; set; } = null!; -} - -public class SecretsManagerPurchasablesDTO -{ - public FreeOrScalableDTO Seats { get; set; } = null!; - public FreeOrScalableDTO ServiceAccounts { get; set; } = null!; -} diff --git a/src/Core/Billing/Pricing/Models/Purchasable.cs b/src/Core/Billing/Pricing/Models/Purchasable.cs new file mode 100644 index 0000000000..7cb4ee00c1 --- /dev/null +++ b/src/Core/Billing/Pricing/Models/Purchasable.cs @@ -0,0 +1,135 @@ +using System.Text.Json; +using System.Text.Json.Serialization; +using OneOf; + +namespace Bit.Core.Billing.Pricing.Models; + +[JsonConverter(typeof(PurchasableJsonConverter))] +public class Purchasable(OneOf input) : OneOfBase(input) +{ + public static implicit operator Purchasable(Free free) => new(free); + public static implicit operator Purchasable(Packaged packaged) => new(packaged); + public static implicit operator Purchasable(Scalable scalable) => new(scalable); + + public T? FromFree(Func select, Func? fallback = null) => + IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default; + + public T? FromPackaged(Func select, Func? fallback = null) => + IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default; + + public T? FromScalable(Func select, Func? fallback = null) => + IsT2 ? select(AsT2) : fallback != null ? fallback(this) : default; + + public bool IsFree => IsT0; + public bool IsPackaged => IsT1; + public bool IsScalable => IsT2; +} + +internal class PurchasableJsonConverter : JsonConverter +{ + private static readonly string _typePropertyName = nameof(Free.Type).ToLower(); + + public override Purchasable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + + if (!element.TryGetProperty(options.PropertyNamingPolicy?.ConvertName(_typePropertyName) ?? _typePropertyName, out var typeProperty)) + { + throw new JsonException( + $"Failed to deserialize {nameof(Purchasable)}: missing '{_typePropertyName}' property"); + } + + var type = typeProperty.GetString(); + + return type switch + { + "free" => element.Deserialize(options)!, + "packaged" => element.Deserialize(options)!, + "scalable" => element.Deserialize(options)!, + _ => throw new JsonException($"Failed to deserialize {nameof(Purchasable)}: invalid '{_typePropertyName}' value - '{type}'"), + }; + } + + public override void Write(Utf8JsonWriter writer, Purchasable value, JsonSerializerOptions options) + => value.Switch( + free => JsonSerializer.Serialize(writer, free, options), + packaged => JsonSerializer.Serialize(writer, packaged, options), + scalable => JsonSerializer.Serialize(writer, scalable, options) + ); +} + +[JsonConverter(typeof(FreeOrScalableJsonConverter))] +public class FreeOrScalable(OneOf input) : OneOfBase(input) +{ + public static implicit operator FreeOrScalable(Free free) => new(free); + public static implicit operator FreeOrScalable(Scalable scalable) => new(scalable); + + public T? FromFree(Func select, Func? fallback = null) => + IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default; + + public T? FromScalable(Func select, Func? fallback = null) => + IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default; + + public bool IsFree => IsT0; + public bool IsScalable => IsT1; +} + +public class FreeOrScalableJsonConverter : JsonConverter +{ + private static readonly string _typePropertyName = nameof(Free.Type).ToLower(); + + public override FreeOrScalable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + + if (!element.TryGetProperty(options.PropertyNamingPolicy?.ConvertName(_typePropertyName) ?? _typePropertyName, out var typeProperty)) + { + throw new JsonException( + $"Failed to deserialize {nameof(FreeOrScalable)}: missing '{_typePropertyName}' property"); + } + + var type = typeProperty.GetString(); + + return type switch + { + "free" => element.Deserialize(options)!, + "scalable" => element.Deserialize(options)!, + _ => throw new JsonException($"Failed to deserialize {nameof(FreeOrScalable)}: invalid '{_typePropertyName}' value - '{type}'"), + }; + } + + public override void Write(Utf8JsonWriter writer, FreeOrScalable value, JsonSerializerOptions options) + => value.Switch( + free => JsonSerializer.Serialize(writer, free, options), + scalable => JsonSerializer.Serialize(writer, scalable, options) + ); +} + +public class Free +{ + public int Quantity { get; set; } + public string Type => "free"; +} + +public class Packaged +{ + public int Quantity { get; set; } + public string StripePriceId { get; set; } = null!; + public decimal Price { get; set; } + public AdditionalSeats? Additional { get; set; } + public string Type => "packaged"; + + public class AdditionalSeats + { + public string StripePriceId { get; set; } = null!; + public decimal Price { get; set; } + } +} + +public class Scalable +{ + public int Provided { get; set; } + public string StripePriceId { get; set; } = null!; + public decimal Price { get; set; } + public string Type => "scalable"; +} diff --git a/src/Core/Billing/Pricing/Models/PurchasableDTO.cs b/src/Core/Billing/Pricing/Models/PurchasableDTO.cs deleted file mode 100644 index 8ba1c7b731..0000000000 --- a/src/Core/Billing/Pricing/Models/PurchasableDTO.cs +++ /dev/null @@ -1,73 +0,0 @@ -using System.Text.Json.Serialization; -using Bit.Core.Billing.Pricing.JSON; -using OneOf; - -namespace Bit.Core.Billing.Pricing.Models; - -#nullable enable - -[JsonConverter(typeof(PurchasableDTOJsonConverter))] -public class PurchasableDTO(OneOf input) : OneOfBase(input) -{ - public static implicit operator PurchasableDTO(FreeDTO free) => new(free); - public static implicit operator PurchasableDTO(PackagedDTO packaged) => new(packaged); - public static implicit operator PurchasableDTO(ScalableDTO scalable) => new(scalable); - - public T? FromFree(Func select, Func? fallback = null) => - IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default; - - public T? FromPackaged(Func select, Func? fallback = null) => - IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default; - - public T? FromScalable(Func select, Func? fallback = null) => - IsT2 ? select(AsT2) : fallback != null ? fallback(this) : default; - - public bool IsFree => IsT0; - public bool IsPackaged => IsT1; - public bool IsScalable => IsT2; -} - -[JsonConverter(typeof(FreeOrScalableDTOJsonConverter))] -public class FreeOrScalableDTO(OneOf input) : OneOfBase(input) -{ - public static implicit operator FreeOrScalableDTO(FreeDTO freeDTO) => new(freeDTO); - public static implicit operator FreeOrScalableDTO(ScalableDTO scalableDTO) => new(scalableDTO); - - public T? FromFree(Func select, Func? fallback = null) => - IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default; - - public T? FromScalable(Func select, Func? fallback = null) => - IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default; - - public bool IsFree => IsT0; - public bool IsScalable => IsT1; -} - -public class FreeDTO -{ - public int Quantity { get; set; } - public string Type => "free"; -} - -public class PackagedDTO -{ - public int Quantity { get; set; } - public string StripePriceId { get; set; } = null!; - public decimal Price { get; set; } - public AdditionalSeats? Additional { get; set; } - public string Type => "packaged"; - - public class AdditionalSeats - { - public string StripePriceId { get; set; } = null!; - public decimal Price { get; set; } - } -} - -public class ScalableDTO -{ - public int Provided { get; set; } - public string StripePriceId { get; set; } = null!; - public decimal Price { get; set; } - public string Type => "scalable"; -} diff --git a/src/Core/Billing/Pricing/PlanAdapter.cs b/src/Core/Billing/Pricing/PlanAdapter.cs index 45a48c3f80..560987b891 100644 --- a/src/Core/Billing/Pricing/PlanAdapter.cs +++ b/src/Core/Billing/Pricing/PlanAdapter.cs @@ -1,14 +1,12 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing.Models; -using Bit.Core.Models.StaticStore; - -#nullable enable +using Plan = Bit.Core.Billing.Pricing.Models.Plan; namespace Bit.Core.Billing.Pricing; -public record PlanAdapter : Plan +public record PlanAdapter : Core.Models.StaticStore.Plan { - public PlanAdapter(PlanDTO plan) + public PlanAdapter(Plan plan) { Type = ToPlanType(plan.LookupKey); ProductTier = ToProductTierType(Type); @@ -88,7 +86,7 @@ public record PlanAdapter : Plan _ => throw new BillingException() // TODO: Flesh out }; - private static PasswordManagerPlanFeatures ToPasswordManagerPlanFeatures(PlanDTO plan) + private static PasswordManagerPlanFeatures ToPasswordManagerPlanFeatures(Plan plan) { var stripePlanId = GetStripePlanId(plan.Seats); var stripeSeatPlanId = GetStripeSeatPlanId(plan.Seats); @@ -128,7 +126,7 @@ public record PlanAdapter : Plan }; } - private static SecretsManagerPlanFeatures ToSecretsManagerPlanFeatures(PlanDTO plan) + private static SecretsManagerPlanFeatures ToSecretsManagerPlanFeatures(Plan plan) { var seats = plan.SecretsManager!.Seats; var serviceAccounts = plan.SecretsManager.ServiceAccounts; @@ -165,62 +163,62 @@ public record PlanAdapter : Plan }; } - private static decimal? GetAdditionalPricePerServiceAccount(FreeOrScalableDTO freeOrScalable) + private static decimal? GetAdditionalPricePerServiceAccount(FreeOrScalable freeOrScalable) => freeOrScalable.FromScalable(x => x.Price); - private static decimal GetBasePrice(PurchasableDTO purchasable) + private static decimal GetBasePrice(Purchasable purchasable) => purchasable.FromPackaged(x => x.Price); - private static int GetBaseSeats(FreeOrScalableDTO freeOrScalable) + private static int GetBaseSeats(FreeOrScalable freeOrScalable) => freeOrScalable.Match( free => free.Quantity, scalable => scalable.Provided); - private static int GetBaseSeats(PurchasableDTO purchasable) + private static int GetBaseSeats(Purchasable purchasable) => purchasable.Match( free => free.Quantity, packaged => packaged.Quantity, scalable => scalable.Provided); - private static short GetBaseServiceAccount(FreeOrScalableDTO freeOrScalable) + private static short GetBaseServiceAccount(FreeOrScalable freeOrScalable) => freeOrScalable.Match( free => (short)free.Quantity, scalable => (short)scalable.Provided); - private static short? GetMaxSeats(PurchasableDTO purchasable) + private static short? GetMaxSeats(Purchasable purchasable) => purchasable.Match( free => (short)free.Quantity, packaged => (short)packaged.Quantity, _ => null); - private static short? GetMaxSeats(FreeOrScalableDTO freeOrScalable) + private static short? GetMaxSeats(FreeOrScalable freeOrScalable) => freeOrScalable.FromFree(x => (short)x.Quantity); - private static short? GetMaxServiceAccounts(FreeOrScalableDTO freeOrScalable) + private static short? GetMaxServiceAccounts(FreeOrScalable freeOrScalable) => freeOrScalable.FromFree(x => (short)x.Quantity); - private static decimal GetSeatPrice(PurchasableDTO purchasable) + private static decimal GetSeatPrice(Purchasable purchasable) => purchasable.Match( _ => 0, packaged => packaged.Additional?.Price ?? 0, scalable => scalable.Price); - private static decimal GetSeatPrice(FreeOrScalableDTO freeOrScalable) + private static decimal GetSeatPrice(FreeOrScalable freeOrScalable) => freeOrScalable.FromScalable(x => x.Price); - private static string? GetStripePlanId(PurchasableDTO purchasable) + private static string? GetStripePlanId(Purchasable purchasable) => purchasable.FromPackaged(x => x.StripePriceId); - private static string? GetStripeSeatPlanId(PurchasableDTO purchasable) + private static string? GetStripeSeatPlanId(Purchasable purchasable) => purchasable.Match( _ => null, packaged => packaged.Additional?.StripePriceId, scalable => scalable.StripePriceId); - private static string? GetStripeSeatPlanId(FreeOrScalableDTO freeOrScalable) + private static string? GetStripeSeatPlanId(FreeOrScalable freeOrScalable) => freeOrScalable.FromScalable(x => x.StripePriceId); - private static string? GetStripeServiceAccountPlanId(FreeOrScalableDTO freeOrScalable) + private static string? GetStripeServiceAccountPlanId(FreeOrScalable freeOrScalable) => freeOrScalable.FromScalable(x => x.StripePriceId); #endregion diff --git a/src/Core/Billing/Pricing/PricingClient.cs b/src/Core/Billing/Pricing/PricingClient.cs index 14caa54eb4..a3db8ce07f 100644 --- a/src/Core/Billing/Pricing/PricingClient.cs +++ b/src/Core/Billing/Pricing/PricingClient.cs @@ -1,7 +1,6 @@ using System.Net; using System.Net.Http.Json; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Pricing.Models; using Bit.Core.Exceptions; using Bit.Core.Services; using Bit.Core.Settings; @@ -45,7 +44,7 @@ public class PricingClient( if (response.IsSuccessStatusCode) { - var plan = await response.Content.ReadFromJsonAsync(); + var plan = await response.Content.ReadFromJsonAsync(); if (plan == null) { throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); @@ -93,7 +92,7 @@ public class PricingClient( if (response.IsSuccessStatusCode) { - var plans = await response.Content.ReadFromJsonAsync>(); + var plans = await response.Content.ReadFromJsonAsync>(); if (plans == null) { throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); diff --git a/src/Core/Billing/Services/ISubscriberService.cs b/src/Core/Billing/Services/ISubscriberService.cs index ef43bde010..5f656b2c22 100644 --- a/src/Core/Billing/Services/ISubscriberService.cs +++ b/src/Core/Billing/Services/ISubscriberService.cs @@ -36,6 +36,9 @@ public interface ISubscriberService ISubscriber subscriber, string paymentMethodNonce); + Task CreateStripeCustomer( + ISubscriber subscriber); + /// /// Retrieves a Stripe using the 's property. /// diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 7a0e78a6dc..73696846ac 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; @@ -13,6 +14,7 @@ using Bit.Core.Billing.Tax.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; @@ -27,14 +29,19 @@ using Subscription = Stripe.Subscription; namespace Bit.Core.Billing.Services.Implementations; +using static StripeConstants; + public class SubscriberService( IBraintreeGateway braintreeGateway, IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, + IOrganizationRepository organizationRepository, + IProviderRepository providerRepository, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ITaxService taxService) : ISubscriberService + ITaxService taxService, + IUserRepository userRepository) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -146,6 +153,110 @@ public class SubscriberService( throw new BillingException(); } +#nullable enable + public async Task CreateStripeCustomer(ISubscriber subscriber) + { + if (!string.IsNullOrEmpty(subscriber.GatewayCustomerId)) + { + throw new ConflictException("Subscriber already has a linked Stripe Customer"); + } + + var options = subscriber switch + { + Organization organization => new CustomerCreateOptions + { + Description = organization.DisplayBusinessName(), + Email = organization.BillingEmail, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = + [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = organization.SubscriberType(), + Value = Max30Characters(organization.DisplayName()) + } + ] + }, + Metadata = new Dictionary + { + [MetadataKeys.OrganizationId] = organization.Id.ToString(), + [MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion + } + }, + Provider provider => new CustomerCreateOptions + { + Description = provider.DisplayBusinessName(), + Email = provider.BillingEmail, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = + [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = provider.SubscriberType(), + Value = Max30Characters(provider.DisplayName()) + } + ] + }, + Metadata = new Dictionary + { + [MetadataKeys.ProviderId] = provider.Id.ToString(), + [MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion + } + }, + User user => new CustomerCreateOptions + { + Description = user.Name, + Email = user.Email, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = + [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = user.SubscriberType(), + Value = Max30Characters(user.SubscriberName()) + } + ] + }, + Metadata = new Dictionary + { + [MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, + [MetadataKeys.UserId] = user.Id.ToString() + } + }, + _ => throw new ArgumentOutOfRangeException(nameof(subscriber)) + }; + + var customer = await stripeAdapter.CustomerCreateAsync(options); + + switch (subscriber) + { + case Organization organization: + organization.Gateway = GatewayType.Stripe; + organization.GatewayCustomerId = customer.Id; + await organizationRepository.ReplaceAsync(organization); + break; + case Provider provider: + provider.Gateway = GatewayType.Stripe; + provider.GatewayCustomerId = customer.Id; + await providerRepository.ReplaceAsync(provider); + break; + case User user: + user.Gateway = GatewayType.Stripe; + user.GatewayCustomerId = customer.Id; + await userRepository.ReplaceAsync(user); + break; + } + + return customer; + + string? Max30Characters(string? input) + => input?.Length <= 30 ? input : input?[..30]; + } +#nullable disable + public async Task GetCustomer( ISubscriber subscriber, CustomerGetOptions customerGetOptions = null) diff --git a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs index 86f233232f..6e061293c7 100644 --- a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs +++ b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; @@ -20,8 +19,11 @@ public class PreviewTaxAmountCommand( ILogger logger, IPricingClient pricingClient, IStripeAdapter stripeAdapter, - ITaxService taxService) : BillingCommand(logger), IPreviewTaxAmountCommand + ITaxService taxService) : BaseBillingCommand(logger), IPreviewTaxAmountCommand { + protected override Conflict DefaultConflict + => new("We had a problem calculating your tax obligation. Please contact support for assistance."); + public Task> Run(OrganizationTrialParameters parameters) => HandleAsync(async () => { diff --git a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs index 453d0c78e9..c42049d5bb 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs @@ -3,6 +3,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Services; using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using Microsoft.Extensions.Logging; @@ -16,14 +17,15 @@ using static StripeConstants; public class UpdateBillingAddressCommandTests { - private readonly IStripeAdapter _stripeAdapter; + private readonly ISubscriberService _subscriberService = Substitute.For(); + private readonly IStripeAdapter _stripeAdapter = Substitute.For(); private readonly UpdateBillingAddressCommand _command; public UpdateBillingAddressCommandTests() { - _stripeAdapter = Substitute.For(); _command = new UpdateBillingAddressCommand( Substitute.For>(), + _subscriberService, _stripeAdapter); } @@ -86,6 +88,66 @@ public class UpdateBillingAddressCommandTests Arg.Is(options => options.AutomaticTax.Enabled == true)); } + [Fact] + public async Task Run_PersonalOrganization_NoCurrentCustomer_MakesCorrectInvocations_ReturnsBillingAddress() + { + var organization = new Organization + { + PlanType = PlanType.FamiliesAnnually, + GatewaySubscriptionId = "sub_123" + }; + + var input = new BillingAddress + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Suite 100", + City = "New York", + State = "NY" + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Suite 100", + City = "New York", + State = "NY" + }, + Subscriptions = new StripeList + { + Data = + [ + new Subscription + { + Id = organization.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } + }; + + _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Address.Matches(input) && + options.HasExpansions("subscriptions") + )).Returns(customer); + + var result = await _command.Run(organization, input); + + Assert.True(result.IsT0); + var output = result.AsT0; + Assert.Equivalent(input, output); + + await _subscriberService.Received(1).CreateStripeCustomer(organization); + + await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); + } + [Fact] public async Task Run_BusinessOrganization_MakesCorrectInvocations_ReturnsBillingAddress() { diff --git a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs index e7bc5c787c..8b1f915658 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs @@ -45,7 +45,8 @@ public class UpdatePaymentMethodCommandTests { var organization = new Organization { - Id = Guid.NewGuid() + Id = Guid.NewGuid(), + GatewayCustomerId = "cus_123" }; var customer = new Customer @@ -100,13 +101,75 @@ public class UpdatePaymentMethodCommandTests } [Fact] - public async Task Run_BankAccount_StripeToPayPal_MakesCorrectInvocations_ReturnsMaskedBankAccount() + public async Task Run_BankAccount_NoCurrentCustomer_MakesCorrectInvocations_ReturnsMaskedBankAccount() { var organization = new Organization { Id = Guid.NewGuid() }; + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345" + }, + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + const string token = "TOKEN"; + + var setupIntent = new SetupIntent + { + Id = "seti_123", + PaymentMethod = + new PaymentMethod + { + Type = "us_bank_account", + UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } + }, + NextAction = new SetupIntentNextAction + { + VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() + }, + Status = "requires_action" + }; + + _stripeAdapter.SetupIntentList(Arg.Is(options => + options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); + + var result = await _command.Run(organization, + new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = token }, new BillingAddress + { + Country = "US", + PostalCode = "12345" + }); + + Assert.True(result.IsT0); + var maskedPaymentMethod = result.AsT0; + Assert.True(maskedPaymentMethod.IsT0); + var maskedBankAccount = maskedPaymentMethod.AsT0; + Assert.Equal("Chase", maskedBankAccount.BankName); + Assert.Equal("9999", maskedBankAccount.Last4); + Assert.False(maskedBankAccount.Verified); + + await _subscriberService.Received(1).CreateStripeCustomer(organization); + + await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id); + } + + [Fact] + public async Task Run_BankAccount_StripeToPayPal_MakesCorrectInvocations_ReturnsMaskedBankAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid(), + GatewayCustomerId = "cus_123" + }; + var customer = new Customer { Address = new Address @@ -170,7 +233,8 @@ public class UpdatePaymentMethodCommandTests { var organization = new Organization { - Id = Guid.NewGuid() + Id = Guid.NewGuid(), + GatewayCustomerId = "cus_123" }; var customer = new Customer @@ -227,7 +291,8 @@ public class UpdatePaymentMethodCommandTests { var organization = new Organization { - Id = Guid.NewGuid() + Id = Guid.NewGuid(), + GatewayCustomerId = "cus_123" }; var customer = new Customer @@ -282,7 +347,8 @@ public class UpdatePaymentMethodCommandTests { var organization = new Organization { - Id = Guid.NewGuid() + Id = Guid.NewGuid(), + GatewayCustomerId = "cus_123" }; var customer = new Customer @@ -343,7 +409,8 @@ public class UpdatePaymentMethodCommandTests { var organization = new Organization { - Id = Guid.NewGuid() + Id = Guid.NewGuid(), + GatewayCustomerId = "cus_123" }; var customer = new Customer diff --git a/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs b/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs index 345f2dfab8..39753857d5 100644 --- a/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs +++ b/test/Core.Test/Billing/Payment/Models/MaskedPaymentMethodTests.cs @@ -25,6 +25,27 @@ public class MaskedPaymentMethodTests Assert.Equivalent(input.AsT0, output.AsT0); } + [Fact] + public void Write_Read_BankAccount_WithOptions_Succeeds() + { + MaskedPaymentMethod input = new MaskedBankAccount + { + BankName = "Chase", + Last4 = "9999", + Verified = true + }; + + var jsonSerializerOptions = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + + var json = JsonSerializer.Serialize(input, jsonSerializerOptions); + + var output = JsonSerializer.Deserialize(json, jsonSerializerOptions); + Assert.NotNull(output); + Assert.True(output.IsT0); + + Assert.Equivalent(input.AsT0, output.AsT0); + } + [Fact] public void Write_Read_Card_Succeeds() { diff --git a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs index 4d82b4b5c9..8a4475268d 100644 --- a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs +++ b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs @@ -8,6 +8,7 @@ using Bit.Core.Test.Billing.Extensions; using Braintree; using Microsoft.Extensions.Logging; using NSubstitute; +using NSubstitute.ReturnsExtensions; using Stripe; using Xunit; using Customer = Stripe.Customer; @@ -35,6 +36,23 @@ public class GetPaymentMethodQueryTests _subscriberService); } + [Fact] + public async Task Run_NoCustomer_ReturnsNull() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).ReturnsNull(); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.Null(maskedPaymentMethod); + } + [Fact] public async Task Run_NoPaymentMethod_ReturnsNull() {