1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-20 08:57:07 -05:00

[PM-23687] Support free organizations on Payment Details page (#6084)

* Resolve JSON serialization bug in OneOf converters and organize pricing models

* Support free organizations for payment method and billing address flows

* Run dotnet format
This commit is contained in:
Alex Morask
2025-07-14 12:39:49 -05:00
committed by GitHub
parent 0e4e060f22
commit d914ab8a98
30 changed files with 575 additions and 316 deletions

View File

@ -1,4 +1,5 @@
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Exceptions;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Stripe; using Stripe;
@ -6,11 +7,17 @@ namespace Bit.Core.Billing.Commands;
using static StripeConstants; using static StripeConstants;
public abstract class BillingCommand<T>( public abstract class BaseBillingCommand<T>(
ILogger<T> logger) ILogger<T> logger)
{ {
protected string CommandName => GetType().Name; protected string CommandName => GetType().Name;
/// <summary>
/// Override this property to set a client-facing conflict response in the case a <see cref="ConflictException"/> is thrown
/// during the command's execution.
/// </summary>
protected virtual Conflict? DefaultConflict => null;
/// <summary> /// <summary>
/// Executes the provided function within a predefined execution context, handling any exceptions that occur during the process. /// Executes the provided function within a predefined execution context, handling any exceptions that occur during the process.
/// </summary> /// </summary>
@ -29,23 +36,35 @@ public abstract class BillingCommand<T>(
return stripeException.StripeError.Code switch return stripeException.StripeError.Code switch
{ {
ErrorCodes.CustomerTaxLocationInvalid => ErrorCodes.CustomerTaxLocationInvalid =>
new BadRequest("Your location wasn't recognized. Please ensure your country and postal code are valid and try again."), new BadRequest(
"Your location wasn't recognized. Please ensure your country and postal code are valid and try again."),
ErrorCodes.PaymentMethodMicroDepositVerificationAttemptsExceeded => ErrorCodes.PaymentMethodMicroDepositVerificationAttemptsExceeded =>
new BadRequest("You have exceeded the number of allowed verification attempts. Please contact support for assistance."), new BadRequest(
"You have exceeded the number of allowed verification attempts. Please contact support for assistance."),
ErrorCodes.PaymentMethodMicroDepositVerificationDescriptorCodeMismatch => ErrorCodes.PaymentMethodMicroDepositVerificationDescriptorCodeMismatch =>
new BadRequest("The verification code you provided does not match the one sent to your bank account. Please try again."), new BadRequest(
"The verification code you provided does not match the one sent to your bank account. Please try again."),
ErrorCodes.PaymentMethodMicroDepositVerificationTimeout => ErrorCodes.PaymentMethodMicroDepositVerificationTimeout =>
new BadRequest("Your bank account was not verified within the required time period. Please contact support for assistance."), new BadRequest(
"Your bank account was not verified within the required time period. Please contact support for assistance."),
ErrorCodes.TaxIdInvalid => ErrorCodes.TaxIdInvalid =>
new BadRequest("The tax ID number you provided was invalid. Please try again or contact support for assistance."), new BadRequest(
"The tax ID number you provided was invalid. Please try again or contact support for assistance."),
_ => new Unhandled(stripeException) _ => new Unhandled(stripeException)
}; };
} }
catch (ConflictException conflictException)
{
logger.LogError("{Command}: {Message}", CommandName, conflictException.Message);
return DefaultConflict != null ?
DefaultConflict :
new Unhandled(conflictException);
}
catch (StripeException stripeException) catch (StripeException stripeException)
{ {
logger.LogError(stripeException, logger.LogError(stripeException,

View File

@ -60,6 +60,7 @@ public static class StripeConstants
public const string InvoiceApproved = "invoice_approved"; public const string InvoiceApproved = "invoice_approved";
public const string OrganizationId = "organizationId"; public const string OrganizationId = "organizationId";
public const string ProviderId = "providerId"; public const string ProviderId = "providerId";
public const string Region = "region";
public const string RetiredBraintreeCustomerId = "btCustomerId_old"; public const string RetiredBraintreeCustomerId = "btCustomerId_old";
public const string UserId = "userId"; public const string UserId = "userId";
} }

View File

@ -1,5 +1,4 @@
#nullable enable using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Commands; using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Payment.Clients; using Bit.Core.Billing.Payment.Clients;
@ -21,8 +20,10 @@ public interface ICreateBitPayInvoiceForCreditCommand
public class CreateBitPayInvoiceForCreditCommand( public class CreateBitPayInvoiceForCreditCommand(
IBitPayClient bitPayClient, IBitPayClient bitPayClient,
GlobalSettings globalSettings, GlobalSettings globalSettings,
ILogger<CreateBitPayInvoiceForCreditCommand> logger) : BillingCommand<CreateBitPayInvoiceForCreditCommand>(logger), ICreateBitPayInvoiceForCreditCommand ILogger<CreateBitPayInvoiceForCreditCommand> logger) : BaseBillingCommand<CreateBitPayInvoiceForCreditCommand>(logger), ICreateBitPayInvoiceForCreditCommand
{ {
protected override Conflict DefaultConflict => new("We had a problem applying your account credit. Please contact support for assistance.");
public Task<BillingCommandResult<string>> Run( public Task<BillingCommandResult<string>> Run(
ISubscriber subscriber, ISubscriber subscriber,
decimal amount, decimal amount,

View File

@ -1,8 +1,8 @@
#nullable enable using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Services; using Bit.Core.Services;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
@ -19,14 +19,26 @@ public interface IUpdateBillingAddressCommand
public class UpdateBillingAddressCommand( public class UpdateBillingAddressCommand(
ILogger<UpdateBillingAddressCommand> logger, ILogger<UpdateBillingAddressCommand> logger,
IStripeAdapter stripeAdapter) : BillingCommand<UpdateBillingAddressCommand>(logger), IUpdateBillingAddressCommand ISubscriberService subscriberService,
IStripeAdapter stripeAdapter) : BaseBillingCommand<UpdateBillingAddressCommand>(logger), IUpdateBillingAddressCommand
{ {
protected override Conflict DefaultConflict =>
new("We had a problem updating your billing address. Please contact support for assistance.");
public Task<BillingCommandResult<BillingAddress>> Run( public Task<BillingCommandResult<BillingAddress>> Run(
ISubscriber subscriber, ISubscriber subscriber,
BillingAddress billingAddress) => HandleAsync(() => subscriber.GetProductUsageType() switch BillingAddress billingAddress) => HandleAsync(async () =>
{ {
ProductUsageType.Personal => UpdatePersonalBillingAddressAsync(subscriber, billingAddress), if (string.IsNullOrEmpty(subscriber.GatewayCustomerId))
ProductUsageType.Business => UpdateBusinessBillingAddressAsync(subscriber, billingAddress) {
await subscriberService.CreateStripeCustomer(subscriber);
}
return subscriber.GetProductUsageType() switch
{
ProductUsageType.Personal => await UpdatePersonalBillingAddressAsync(subscriber, billingAddress),
ProductUsageType.Business => await UpdateBusinessBillingAddressAsync(subscriber, billingAddress)
};
}); });
private async Task<BillingCommandResult<BillingAddress>> UpdatePersonalBillingAddressAsync( private async Task<BillingCommandResult<BillingAddress>> UpdatePersonalBillingAddressAsync(

View File

@ -1,5 +1,4 @@
#nullable enable using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Commands; using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;
@ -29,16 +28,22 @@ public class UpdatePaymentMethodCommand(
ILogger<UpdatePaymentMethodCommand> logger, ILogger<UpdatePaymentMethodCommand> logger,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService) : BillingCommand<UpdatePaymentMethodCommand>(logger), IUpdatePaymentMethodCommand ISubscriberService subscriberService) : BaseBillingCommand<UpdatePaymentMethodCommand>(logger), IUpdatePaymentMethodCommand
{ {
private readonly ILogger<UpdatePaymentMethodCommand> _logger = logger; private readonly ILogger<UpdatePaymentMethodCommand> _logger = logger;
private static readonly Conflict _conflict = new("We had a problem updating your payment method. Please contact support for assistance."); protected override Conflict DefaultConflict
=> new("We had a problem updating your payment method. Please contact support for assistance.");
public Task<BillingCommandResult<MaskedPaymentMethod>> Run( public Task<BillingCommandResult<MaskedPaymentMethod>> Run(
ISubscriber subscriber, ISubscriber subscriber,
TokenizedPaymentMethod paymentMethod, TokenizedPaymentMethod paymentMethod,
BillingAddress? billingAddress) => HandleAsync(async () => BillingAddress? billingAddress) => HandleAsync(async () =>
{ {
if (string.IsNullOrEmpty(subscriber.GatewayCustomerId))
{
await subscriberService.CreateStripeCustomer(subscriber);
}
var customer = await subscriberService.GetCustomer(subscriber); var customer = await subscriberService.GetCustomer(subscriber);
var result = paymentMethod.Type switch var result = paymentMethod.Type switch
@ -80,10 +85,10 @@ public class UpdatePaymentMethodCommand(
{ {
case 0: case 0:
_logger.LogError("{Command}: Could not find setup intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id); _logger.LogError("{Command}: Could not find setup intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id);
return _conflict; return DefaultConflict;
case > 1: case > 1:
_logger.LogError("{Command}: Found more than one set up intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id); _logger.LogError("{Command}: Found more than one set up intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id);
return _conflict; return DefaultConflict;
} }
var setupIntent = setupIntents.First(); var setupIntent = setupIntents.First();

View File

@ -1,5 +1,4 @@
#nullable enable using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Commands; using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;
using Bit.Core.Entities; using Bit.Core.Entities;
@ -19,12 +18,12 @@ public interface IVerifyBankAccountCommand
public class VerifyBankAccountCommand( public class VerifyBankAccountCommand(
ILogger<VerifyBankAccountCommand> logger, ILogger<VerifyBankAccountCommand> logger,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter) : BillingCommand<VerifyBankAccountCommand>(logger), IVerifyBankAccountCommand IStripeAdapter stripeAdapter) : BaseBillingCommand<VerifyBankAccountCommand>(logger), IVerifyBankAccountCommand
{ {
private readonly ILogger<VerifyBankAccountCommand> _logger = logger; private readonly ILogger<VerifyBankAccountCommand> _logger = logger;
private static readonly Conflict _conflict = protected override Conflict DefaultConflict
new("We had a problem verifying your bank account. Please contact support for assistance."); => new("We had a problem verifying your bank account. Please contact support for assistance.");
public Task<BillingCommandResult<MaskedPaymentMethod>> Run( public Task<BillingCommandResult<MaskedPaymentMethod>> Run(
ISubscriber subscriber, ISubscriber subscriber,
@ -37,7 +36,7 @@ public class VerifyBankAccountCommand(
_logger.LogError( _logger.LogError(
"{Command}: Could not find setup intent to verify subscriber's ({SubscriberID}) bank account", "{Command}: Could not find setup intent to verify subscriber's ({SubscriberID}) bank account",
CommandName, subscriber.Id); CommandName, subscriber.Id);
return _conflict; return DefaultConflict;
} }
await stripeAdapter.SetupIntentVerifyMicroDeposit(setupIntentId, await stripeAdapter.SetupIntentVerifyMicroDeposit(setupIntentId,

View File

@ -1,5 +1,4 @@
#nullable enable using Stripe;
using Stripe;
namespace Bit.Core.Billing.Payment.Models; namespace Bit.Core.Billing.Payment.Models;

View File

@ -1,7 +1,5 @@
#nullable enable using System.Text.Json;
using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using Bit.Core.Billing.Pricing.JSON;
using Braintree; using Braintree;
using OneOf; using OneOf;
using Stripe; using Stripe;
@ -83,32 +81,28 @@ public class MaskedPaymentMethod(OneOf<MaskedBankAccount, MaskedCard, MaskedPayP
public static MaskedPaymentMethod From(PayPalAccount payPalAccount) => new MaskedPayPalAccount { Email = payPalAccount.Email }; public static MaskedPaymentMethod From(PayPalAccount payPalAccount) => new MaskedPayPalAccount { Email = payPalAccount.Email };
} }
public class MaskedPaymentMethodJsonConverter : TypeReadingJsonConverter<MaskedPaymentMethod> public class MaskedPaymentMethodJsonConverter : JsonConverter<MaskedPaymentMethod>
{ {
protected override string TypePropertyName => nameof(MaskedBankAccount.Type).ToLower(); private const string _typePropertyName = nameof(MaskedBankAccount.Type);
public override MaskedPaymentMethod? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) public override MaskedPaymentMethod Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{ {
var type = ReadType(reader); var element = JsonElement.ParseValue(ref reader);
if (!element.TryGetProperty(options.PropertyNamingPolicy?.ConvertName(_typePropertyName) ?? _typePropertyName, out var typeProperty))
{
throw new JsonException(
$"Failed to deserialize {nameof(MaskedPaymentMethod)}: missing '{_typePropertyName}' property");
}
var type = typeProperty.GetString();
return type switch return type switch
{ {
"bankAccount" => JsonSerializer.Deserialize<MaskedBankAccount>(ref reader, options) switch "bankAccount" => element.Deserialize<MaskedBankAccount>(options)!,
{ "card" => element.Deserialize<MaskedCard>(options)!,
null => null, "payPal" => element.Deserialize<MaskedPayPalAccount>(options)!,
var bankAccount => new MaskedPaymentMethod(bankAccount) _ => throw new JsonException($"Failed to deserialize {nameof(MaskedPaymentMethod)}: invalid '{_typePropertyName}' value - '{type}'")
},
"card" => JsonSerializer.Deserialize<MaskedCard>(ref reader, options) switch
{
null => null,
var card => new MaskedPaymentMethod(card)
},
"payPal" => JsonSerializer.Deserialize<MaskedPayPalAccount>(ref reader, options) switch
{
null => null,
var payPal => new MaskedPaymentMethod(payPal)
},
_ => Skip(ref reader)
}; };
} }

View File

@ -1,5 +1,4 @@
#nullable enable namespace Bit.Core.Billing.Payment.Models;
namespace Bit.Core.Billing.Payment.Models;
public record TokenizedPaymentMethod public record TokenizedPaymentMethod
{ {

View File

@ -1,5 +1,4 @@
#nullable enable using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Entities; using Bit.Core.Entities;

View File

@ -1,5 +1,4 @@
#nullable enable using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
namespace Bit.Core.Billing.Payment.Queries; namespace Bit.Core.Billing.Payment.Queries;

View File

@ -1,5 +1,4 @@
#nullable enable using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;
@ -29,6 +28,11 @@ public class GetPaymentMethodQuery(
var customer = await subscriberService.GetCustomer(subscriber, var customer = await subscriberService.GetCustomer(subscriber,
new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] }); new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] });
if (customer == null)
{
return null;
}
if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId)) if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId))
{ {
var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId);

View File

@ -1,35 +0,0 @@
using System.Text.Json;
using Bit.Core.Billing.Pricing.Models;
namespace Bit.Core.Billing.Pricing.JSON;
#nullable enable
public class FreeOrScalableDTOJsonConverter : TypeReadingJsonConverter<FreeOrScalableDTO>
{
public override FreeOrScalableDTO? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var type = ReadType(reader);
return type switch
{
"free" => JsonSerializer.Deserialize<FreeDTO>(ref reader, options) switch
{
null => null,
var free => new FreeOrScalableDTO(free)
},
"scalable" => JsonSerializer.Deserialize<ScalableDTO>(ref reader, options) switch
{
null => null,
var scalable => new FreeOrScalableDTO(scalable)
},
_ => null
};
}
public override void Write(Utf8JsonWriter writer, FreeOrScalableDTO value, JsonSerializerOptions options)
=> value.Switch(
free => JsonSerializer.Serialize(writer, free, options),
scalable => JsonSerializer.Serialize(writer, scalable, options)
);
}

View File

@ -1,40 +0,0 @@
using System.Text.Json;
using Bit.Core.Billing.Pricing.Models;
namespace Bit.Core.Billing.Pricing.JSON;
#nullable enable
internal class PurchasableDTOJsonConverter : TypeReadingJsonConverter<PurchasableDTO>
{
public override PurchasableDTO? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var type = ReadType(reader);
return type switch
{
"free" => JsonSerializer.Deserialize<FreeDTO>(ref reader, options) switch
{
null => null,
var free => new PurchasableDTO(free)
},
"packaged" => JsonSerializer.Deserialize<PackagedDTO>(ref reader, options) switch
{
null => null,
var packaged => new PurchasableDTO(packaged)
},
"scalable" => JsonSerializer.Deserialize<ScalableDTO>(ref reader, options) switch
{
null => null,
var scalable => new PurchasableDTO(scalable)
},
_ => null
};
}
public override void Write(Utf8JsonWriter writer, PurchasableDTO value, JsonSerializerOptions options)
=> value.Switch(
free => JsonSerializer.Serialize(writer, free, options),
packaged => JsonSerializer.Serialize(writer, packaged, options),
scalable => JsonSerializer.Serialize(writer, scalable, options)
);
}

View File

@ -1,36 +0,0 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using Bit.Core.Billing.Pricing.Models;
namespace Bit.Core.Billing.Pricing.JSON;
#nullable enable
public abstract class TypeReadingJsonConverter<T> : JsonConverter<T> where T : class
{
protected virtual string TypePropertyName => nameof(ScalableDTO.Type).ToLower();
protected string? ReadType(Utf8JsonReader reader)
{
while (reader.Read())
{
if (reader.CurrentDepth != 1 ||
reader.TokenType != JsonTokenType.PropertyName ||
reader.GetString()?.ToLower() != TypePropertyName)
{
continue;
}
reader.Read();
return reader.GetString();
}
return null;
}
protected T? Skip(ref Utf8JsonReader reader)
{
reader.Skip();
return null;
}
}

View File

@ -0,0 +1,7 @@
namespace Bit.Core.Billing.Pricing.Models;
public class Feature
{
public required string Name { get; set; }
public required string LookupKey { get; set; }
}

View File

@ -1,9 +0,0 @@
namespace Bit.Core.Billing.Pricing.Models;
#nullable enable
public class FeatureDTO
{
public string Name { get; set; } = null!;
public string LookupKey { get; set; } = null!;
}

View File

@ -0,0 +1,25 @@
namespace Bit.Core.Billing.Pricing.Models;
public class Plan
{
public required string LookupKey { get; set; }
public required string Name { get; set; }
public required string Tier { get; set; }
public string? Cadence { get; set; }
public int? LegacyYear { get; set; }
public bool Available { get; set; }
public required Feature[] Features { get; set; }
public required Purchasable Seats { get; set; }
public Scalable? ManagedSeats { get; set; }
public Scalable? Storage { get; set; }
public SecretsManagerPurchasables? SecretsManager { get; set; }
public int? TrialPeriodDays { get; set; }
public required string[] CanUpgradeTo { get; set; }
public required Dictionary<string, string> AdditionalData { get; set; }
}
public class SecretsManagerPurchasables
{
public required FreeOrScalable Seats { get; set; }
public required FreeOrScalable ServiceAccounts { get; set; }
}

View File

@ -1,27 +0,0 @@
namespace Bit.Core.Billing.Pricing.Models;
#nullable enable
public class PlanDTO
{
public string LookupKey { get; set; } = null!;
public string Name { get; set; } = null!;
public string Tier { get; set; } = null!;
public string? Cadence { get; set; }
public int? LegacyYear { get; set; }
public bool Available { get; set; }
public FeatureDTO[] Features { get; set; } = null!;
public PurchasableDTO Seats { get; set; } = null!;
public ScalableDTO? ManagedSeats { get; set; }
public ScalableDTO? Storage { get; set; }
public SecretsManagerPurchasablesDTO? SecretsManager { get; set; }
public int? TrialPeriodDays { get; set; }
public string[] CanUpgradeTo { get; set; } = null!;
public Dictionary<string, string> AdditionalData { get; set; } = null!;
}
public class SecretsManagerPurchasablesDTO
{
public FreeOrScalableDTO Seats { get; set; } = null!;
public FreeOrScalableDTO ServiceAccounts { get; set; } = null!;
}

View File

@ -0,0 +1,135 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using OneOf;
namespace Bit.Core.Billing.Pricing.Models;
[JsonConverter(typeof(PurchasableJsonConverter))]
public class Purchasable(OneOf<Free, Packaged, Scalable> input) : OneOfBase<Free, Packaged, Scalable>(input)
{
public static implicit operator Purchasable(Free free) => new(free);
public static implicit operator Purchasable(Packaged packaged) => new(packaged);
public static implicit operator Purchasable(Scalable scalable) => new(scalable);
public T? FromFree<T>(Func<Free, T> select, Func<Purchasable, T>? fallback = null) =>
IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default;
public T? FromPackaged<T>(Func<Packaged, T> select, Func<Purchasable, T>? fallback = null) =>
IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default;
public T? FromScalable<T>(Func<Scalable, T> select, Func<Purchasable, T>? fallback = null) =>
IsT2 ? select(AsT2) : fallback != null ? fallback(this) : default;
public bool IsFree => IsT0;
public bool IsPackaged => IsT1;
public bool IsScalable => IsT2;
}
internal class PurchasableJsonConverter : JsonConverter<Purchasable>
{
private static readonly string _typePropertyName = nameof(Free.Type).ToLower();
public override Purchasable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var element = JsonElement.ParseValue(ref reader);
if (!element.TryGetProperty(options.PropertyNamingPolicy?.ConvertName(_typePropertyName) ?? _typePropertyName, out var typeProperty))
{
throw new JsonException(
$"Failed to deserialize {nameof(Purchasable)}: missing '{_typePropertyName}' property");
}
var type = typeProperty.GetString();
return type switch
{
"free" => element.Deserialize<Free>(options)!,
"packaged" => element.Deserialize<Packaged>(options)!,
"scalable" => element.Deserialize<Scalable>(options)!,
_ => throw new JsonException($"Failed to deserialize {nameof(Purchasable)}: invalid '{_typePropertyName}' value - '{type}'"),
};
}
public override void Write(Utf8JsonWriter writer, Purchasable value, JsonSerializerOptions options)
=> value.Switch(
free => JsonSerializer.Serialize(writer, free, options),
packaged => JsonSerializer.Serialize(writer, packaged, options),
scalable => JsonSerializer.Serialize(writer, scalable, options)
);
}
[JsonConverter(typeof(FreeOrScalableJsonConverter))]
public class FreeOrScalable(OneOf<Free, Scalable> input) : OneOfBase<Free, Scalable>(input)
{
public static implicit operator FreeOrScalable(Free free) => new(free);
public static implicit operator FreeOrScalable(Scalable scalable) => new(scalable);
public T? FromFree<T>(Func<Free, T> select, Func<FreeOrScalable, T>? fallback = null) =>
IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default;
public T? FromScalable<T>(Func<Scalable, T> select, Func<FreeOrScalable, T>? fallback = null) =>
IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default;
public bool IsFree => IsT0;
public bool IsScalable => IsT1;
}
public class FreeOrScalableJsonConverter : JsonConverter<FreeOrScalable>
{
private static readonly string _typePropertyName = nameof(Free.Type).ToLower();
public override FreeOrScalable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var element = JsonElement.ParseValue(ref reader);
if (!element.TryGetProperty(options.PropertyNamingPolicy?.ConvertName(_typePropertyName) ?? _typePropertyName, out var typeProperty))
{
throw new JsonException(
$"Failed to deserialize {nameof(FreeOrScalable)}: missing '{_typePropertyName}' property");
}
var type = typeProperty.GetString();
return type switch
{
"free" => element.Deserialize<Free>(options)!,
"scalable" => element.Deserialize<Scalable>(options)!,
_ => throw new JsonException($"Failed to deserialize {nameof(FreeOrScalable)}: invalid '{_typePropertyName}' value - '{type}'"),
};
}
public override void Write(Utf8JsonWriter writer, FreeOrScalable value, JsonSerializerOptions options)
=> value.Switch(
free => JsonSerializer.Serialize(writer, free, options),
scalable => JsonSerializer.Serialize(writer, scalable, options)
);
}
public class Free
{
public int Quantity { get; set; }
public string Type => "free";
}
public class Packaged
{
public int Quantity { get; set; }
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
public AdditionalSeats? Additional { get; set; }
public string Type => "packaged";
public class AdditionalSeats
{
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
}
}
public class Scalable
{
public int Provided { get; set; }
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
public string Type => "scalable";
}

View File

@ -1,73 +0,0 @@
using System.Text.Json.Serialization;
using Bit.Core.Billing.Pricing.JSON;
using OneOf;
namespace Bit.Core.Billing.Pricing.Models;
#nullable enable
[JsonConverter(typeof(PurchasableDTOJsonConverter))]
public class PurchasableDTO(OneOf<FreeDTO, PackagedDTO, ScalableDTO> input) : OneOfBase<FreeDTO, PackagedDTO, ScalableDTO>(input)
{
public static implicit operator PurchasableDTO(FreeDTO free) => new(free);
public static implicit operator PurchasableDTO(PackagedDTO packaged) => new(packaged);
public static implicit operator PurchasableDTO(ScalableDTO scalable) => new(scalable);
public T? FromFree<T>(Func<FreeDTO, T> select, Func<PurchasableDTO, T>? fallback = null) =>
IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default;
public T? FromPackaged<T>(Func<PackagedDTO, T> select, Func<PurchasableDTO, T>? fallback = null) =>
IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default;
public T? FromScalable<T>(Func<ScalableDTO, T> select, Func<PurchasableDTO, T>? fallback = null) =>
IsT2 ? select(AsT2) : fallback != null ? fallback(this) : default;
public bool IsFree => IsT0;
public bool IsPackaged => IsT1;
public bool IsScalable => IsT2;
}
[JsonConverter(typeof(FreeOrScalableDTOJsonConverter))]
public class FreeOrScalableDTO(OneOf<FreeDTO, ScalableDTO> input) : OneOfBase<FreeDTO, ScalableDTO>(input)
{
public static implicit operator FreeOrScalableDTO(FreeDTO freeDTO) => new(freeDTO);
public static implicit operator FreeOrScalableDTO(ScalableDTO scalableDTO) => new(scalableDTO);
public T? FromFree<T>(Func<FreeDTO, T> select, Func<FreeOrScalableDTO, T>? fallback = null) =>
IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default;
public T? FromScalable<T>(Func<ScalableDTO, T> select, Func<FreeOrScalableDTO, T>? fallback = null) =>
IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default;
public bool IsFree => IsT0;
public bool IsScalable => IsT1;
}
public class FreeDTO
{
public int Quantity { get; set; }
public string Type => "free";
}
public class PackagedDTO
{
public int Quantity { get; set; }
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
public AdditionalSeats? Additional { get; set; }
public string Type => "packaged";
public class AdditionalSeats
{
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
}
}
public class ScalableDTO
{
public int Provided { get; set; }
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
public string Type => "scalable";
}

View File

@ -1,14 +1,12 @@
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing.Models; using Bit.Core.Billing.Pricing.Models;
using Bit.Core.Models.StaticStore; using Plan = Bit.Core.Billing.Pricing.Models.Plan;
#nullable enable
namespace Bit.Core.Billing.Pricing; namespace Bit.Core.Billing.Pricing;
public record PlanAdapter : Plan public record PlanAdapter : Core.Models.StaticStore.Plan
{ {
public PlanAdapter(PlanDTO plan) public PlanAdapter(Plan plan)
{ {
Type = ToPlanType(plan.LookupKey); Type = ToPlanType(plan.LookupKey);
ProductTier = ToProductTierType(Type); ProductTier = ToProductTierType(Type);
@ -88,7 +86,7 @@ public record PlanAdapter : Plan
_ => throw new BillingException() // TODO: Flesh out _ => throw new BillingException() // TODO: Flesh out
}; };
private static PasswordManagerPlanFeatures ToPasswordManagerPlanFeatures(PlanDTO plan) private static PasswordManagerPlanFeatures ToPasswordManagerPlanFeatures(Plan plan)
{ {
var stripePlanId = GetStripePlanId(plan.Seats); var stripePlanId = GetStripePlanId(plan.Seats);
var stripeSeatPlanId = GetStripeSeatPlanId(plan.Seats); var stripeSeatPlanId = GetStripeSeatPlanId(plan.Seats);
@ -128,7 +126,7 @@ public record PlanAdapter : Plan
}; };
} }
private static SecretsManagerPlanFeatures ToSecretsManagerPlanFeatures(PlanDTO plan) private static SecretsManagerPlanFeatures ToSecretsManagerPlanFeatures(Plan plan)
{ {
var seats = plan.SecretsManager!.Seats; var seats = plan.SecretsManager!.Seats;
var serviceAccounts = plan.SecretsManager.ServiceAccounts; var serviceAccounts = plan.SecretsManager.ServiceAccounts;
@ -165,62 +163,62 @@ public record PlanAdapter : Plan
}; };
} }
private static decimal? GetAdditionalPricePerServiceAccount(FreeOrScalableDTO freeOrScalable) private static decimal? GetAdditionalPricePerServiceAccount(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromScalable(x => x.Price); => freeOrScalable.FromScalable(x => x.Price);
private static decimal GetBasePrice(PurchasableDTO purchasable) private static decimal GetBasePrice(Purchasable purchasable)
=> purchasable.FromPackaged(x => x.Price); => purchasable.FromPackaged(x => x.Price);
private static int GetBaseSeats(FreeOrScalableDTO freeOrScalable) private static int GetBaseSeats(FreeOrScalable freeOrScalable)
=> freeOrScalable.Match( => freeOrScalable.Match(
free => free.Quantity, free => free.Quantity,
scalable => scalable.Provided); scalable => scalable.Provided);
private static int GetBaseSeats(PurchasableDTO purchasable) private static int GetBaseSeats(Purchasable purchasable)
=> purchasable.Match( => purchasable.Match(
free => free.Quantity, free => free.Quantity,
packaged => packaged.Quantity, packaged => packaged.Quantity,
scalable => scalable.Provided); scalable => scalable.Provided);
private static short GetBaseServiceAccount(FreeOrScalableDTO freeOrScalable) private static short GetBaseServiceAccount(FreeOrScalable freeOrScalable)
=> freeOrScalable.Match( => freeOrScalable.Match(
free => (short)free.Quantity, free => (short)free.Quantity,
scalable => (short)scalable.Provided); scalable => (short)scalable.Provided);
private static short? GetMaxSeats(PurchasableDTO purchasable) private static short? GetMaxSeats(Purchasable purchasable)
=> purchasable.Match<short?>( => purchasable.Match<short?>(
free => (short)free.Quantity, free => (short)free.Quantity,
packaged => (short)packaged.Quantity, packaged => (short)packaged.Quantity,
_ => null); _ => null);
private static short? GetMaxSeats(FreeOrScalableDTO freeOrScalable) private static short? GetMaxSeats(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromFree(x => (short)x.Quantity); => freeOrScalable.FromFree(x => (short)x.Quantity);
private static short? GetMaxServiceAccounts(FreeOrScalableDTO freeOrScalable) private static short? GetMaxServiceAccounts(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromFree(x => (short)x.Quantity); => freeOrScalable.FromFree(x => (short)x.Quantity);
private static decimal GetSeatPrice(PurchasableDTO purchasable) private static decimal GetSeatPrice(Purchasable purchasable)
=> purchasable.Match( => purchasable.Match(
_ => 0, _ => 0,
packaged => packaged.Additional?.Price ?? 0, packaged => packaged.Additional?.Price ?? 0,
scalable => scalable.Price); scalable => scalable.Price);
private static decimal GetSeatPrice(FreeOrScalableDTO freeOrScalable) private static decimal GetSeatPrice(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromScalable(x => x.Price); => freeOrScalable.FromScalable(x => x.Price);
private static string? GetStripePlanId(PurchasableDTO purchasable) private static string? GetStripePlanId(Purchasable purchasable)
=> purchasable.FromPackaged(x => x.StripePriceId); => purchasable.FromPackaged(x => x.StripePriceId);
private static string? GetStripeSeatPlanId(PurchasableDTO purchasable) private static string? GetStripeSeatPlanId(Purchasable purchasable)
=> purchasable.Match( => purchasable.Match(
_ => null, _ => null,
packaged => packaged.Additional?.StripePriceId, packaged => packaged.Additional?.StripePriceId,
scalable => scalable.StripePriceId); scalable => scalable.StripePriceId);
private static string? GetStripeSeatPlanId(FreeOrScalableDTO freeOrScalable) private static string? GetStripeSeatPlanId(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromScalable(x => x.StripePriceId); => freeOrScalable.FromScalable(x => x.StripePriceId);
private static string? GetStripeServiceAccountPlanId(FreeOrScalableDTO freeOrScalable) private static string? GetStripeServiceAccountPlanId(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromScalable(x => x.StripePriceId); => freeOrScalable.FromScalable(x => x.StripePriceId);
#endregion #endregion

View File

@ -1,7 +1,6 @@
using System.Net; using System.Net;
using System.Net.Http.Json; using System.Net.Http.Json;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing.Models;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
@ -45,7 +44,7 @@ public class PricingClient(
if (response.IsSuccessStatusCode) if (response.IsSuccessStatusCode)
{ {
var plan = await response.Content.ReadFromJsonAsync<PlanDTO>(); var plan = await response.Content.ReadFromJsonAsync<Models.Plan>();
if (plan == null) if (plan == null)
{ {
throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); throw new BillingException(message: "Deserialization of Pricing Service response resulted in null");
@ -93,7 +92,7 @@ public class PricingClient(
if (response.IsSuccessStatusCode) if (response.IsSuccessStatusCode)
{ {
var plans = await response.Content.ReadFromJsonAsync<List<PlanDTO>>(); var plans = await response.Content.ReadFromJsonAsync<List<Models.Plan>>();
if (plans == null) if (plans == null)
{ {
throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); throw new BillingException(message: "Deserialization of Pricing Service response resulted in null");

View File

@ -36,6 +36,9 @@ public interface ISubscriberService
ISubscriber subscriber, ISubscriber subscriber,
string paymentMethodNonce); string paymentMethodNonce);
Task<Customer> CreateStripeCustomer(
ISubscriber subscriber);
/// <summary> /// <summary>
/// Retrieves a Stripe <see cref="Customer"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewayCustomerId"/> property. /// Retrieves a Stripe <see cref="Customer"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewayCustomerId"/> property.
/// </summary> /// </summary>

View File

@ -3,6 +3,7 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
@ -13,6 +14,7 @@ using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
@ -27,14 +29,19 @@ using Subscription = Stripe.Subscription;
namespace Bit.Core.Billing.Services.Implementations; namespace Bit.Core.Billing.Services.Implementations;
using static StripeConstants;
public class SubscriberService( public class SubscriberService(
IBraintreeGateway braintreeGateway, IBraintreeGateway braintreeGateway,
IFeatureService featureService, IFeatureService featureService,
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
ILogger<SubscriberService> logger, ILogger<SubscriberService> logger,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ITaxService taxService) : ISubscriberService ITaxService taxService,
IUserRepository userRepository) : ISubscriberService
{ {
public async Task CancelSubscription( public async Task CancelSubscription(
ISubscriber subscriber, ISubscriber subscriber,
@ -146,6 +153,110 @@ public class SubscriberService(
throw new BillingException(); throw new BillingException();
} }
#nullable enable
public async Task<Customer> CreateStripeCustomer(ISubscriber subscriber)
{
if (!string.IsNullOrEmpty(subscriber.GatewayCustomerId))
{
throw new ConflictException("Subscriber already has a linked Stripe Customer");
}
var options = subscriber switch
{
Organization organization => new CustomerCreateOptions
{
Description = organization.DisplayBusinessName(),
Email = organization.BillingEmail,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = organization.SubscriberType(),
Value = Max30Characters(organization.DisplayName())
}
]
},
Metadata = new Dictionary<string, string>
{
[MetadataKeys.OrganizationId] = organization.Id.ToString(),
[MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion
}
},
Provider provider => new CustomerCreateOptions
{
Description = provider.DisplayBusinessName(),
Email = provider.BillingEmail,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = provider.SubscriberType(),
Value = Max30Characters(provider.DisplayName())
}
]
},
Metadata = new Dictionary<string, string>
{
[MetadataKeys.ProviderId] = provider.Id.ToString(),
[MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion
}
},
User user => new CustomerCreateOptions
{
Description = user.Name,
Email = user.Email,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = user.SubscriberType(),
Value = Max30Characters(user.SubscriberName())
}
]
},
Metadata = new Dictionary<string, string>
{
[MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion,
[MetadataKeys.UserId] = user.Id.ToString()
}
},
_ => throw new ArgumentOutOfRangeException(nameof(subscriber))
};
var customer = await stripeAdapter.CustomerCreateAsync(options);
switch (subscriber)
{
case Organization organization:
organization.Gateway = GatewayType.Stripe;
organization.GatewayCustomerId = customer.Id;
await organizationRepository.ReplaceAsync(organization);
break;
case Provider provider:
provider.Gateway = GatewayType.Stripe;
provider.GatewayCustomerId = customer.Id;
await providerRepository.ReplaceAsync(provider);
break;
case User user:
user.Gateway = GatewayType.Stripe;
user.GatewayCustomerId = customer.Id;
await userRepository.ReplaceAsync(user);
break;
}
return customer;
string? Max30Characters(string? input)
=> input?.Length <= 30 ? input : input?[..30];
}
#nullable disable
public async Task<Customer> GetCustomer( public async Task<Customer> GetCustomer(
ISubscriber subscriber, ISubscriber subscriber,
CustomerGetOptions customerGetOptions = null) CustomerGetOptions customerGetOptions = null)

View File

@ -1,5 +1,4 @@
#nullable enable using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
@ -20,8 +19,11 @@ public class PreviewTaxAmountCommand(
ILogger<PreviewTaxAmountCommand> logger, ILogger<PreviewTaxAmountCommand> logger,
IPricingClient pricingClient, IPricingClient pricingClient,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ITaxService taxService) : BillingCommand<PreviewTaxAmountCommand>(logger), IPreviewTaxAmountCommand ITaxService taxService) : BaseBillingCommand<PreviewTaxAmountCommand>(logger), IPreviewTaxAmountCommand
{ {
protected override Conflict DefaultConflict
=> new("We had a problem calculating your tax obligation. Please contact support for assistance.");
public Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters) public Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters)
=> HandleAsync<decimal>(async () => => HandleAsync<decimal>(async () =>
{ {

View File

@ -3,6 +3,7 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Test.Billing.Extensions; using Bit.Core.Test.Billing.Extensions;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
@ -16,14 +17,15 @@ using static StripeConstants;
public class UpdateBillingAddressCommandTests public class UpdateBillingAddressCommandTests
{ {
private readonly IStripeAdapter _stripeAdapter; private readonly ISubscriberService _subscriberService = Substitute.For<ISubscriberService>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly UpdateBillingAddressCommand _command; private readonly UpdateBillingAddressCommand _command;
public UpdateBillingAddressCommandTests() public UpdateBillingAddressCommandTests()
{ {
_stripeAdapter = Substitute.For<IStripeAdapter>();
_command = new UpdateBillingAddressCommand( _command = new UpdateBillingAddressCommand(
Substitute.For<ILogger<UpdateBillingAddressCommand>>(), Substitute.For<ILogger<UpdateBillingAddressCommand>>(),
_subscriberService,
_stripeAdapter); _stripeAdapter);
} }
@ -86,6 +88,66 @@ public class UpdateBillingAddressCommandTests
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true)); Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
} }
[Fact]
public async Task Run_PersonalOrganization_NoCurrentCustomer_MakesCorrectInvocations_ReturnsBillingAddress()
{
var organization = new Organization
{
PlanType = PlanType.FamiliesAnnually,
GatewaySubscriptionId = "sub_123"
};
var input = new BillingAddress
{
Country = "US",
PostalCode = "12345",
Line1 = "123 Main St.",
Line2 = "Suite 100",
City = "New York",
State = "NY"
};
var customer = new Customer
{
Address = new Address
{
Country = "US",
PostalCode = "12345",
Line1 = "123 Main St.",
Line2 = "Suite 100",
City = "New York",
State = "NY"
},
Subscriptions = new StripeList<Subscription>
{
Data =
[
new Subscription
{
Id = organization.GatewaySubscriptionId,
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
}
]
}
};
_stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
options.Address.Matches(input) &&
options.HasExpansions("subscriptions")
)).Returns(customer);
var result = await _command.Run(organization, input);
Assert.True(result.IsT0);
var output = result.AsT0;
Assert.Equivalent(input, output);
await _subscriberService.Received(1).CreateStripeCustomer(organization);
await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
}
[Fact] [Fact]
public async Task Run_BusinessOrganization_MakesCorrectInvocations_ReturnsBillingAddress() public async Task Run_BusinessOrganization_MakesCorrectInvocations_ReturnsBillingAddress()
{ {

View File

@ -45,7 +45,8 @@ public class UpdatePaymentMethodCommandTests
{ {
var organization = new Organization var organization = new Organization
{ {
Id = Guid.NewGuid() Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
}; };
var customer = new Customer var customer = new Customer
@ -100,13 +101,75 @@ public class UpdatePaymentMethodCommandTests
} }
[Fact] [Fact]
public async Task Run_BankAccount_StripeToPayPal_MakesCorrectInvocations_ReturnsMaskedBankAccount() public async Task Run_BankAccount_NoCurrentCustomer_MakesCorrectInvocations_ReturnsMaskedBankAccount()
{ {
var organization = new Organization var organization = new Organization
{ {
Id = Guid.NewGuid() Id = Guid.NewGuid()
}; };
var customer = new Customer
{
Address = new Address
{
Country = "US",
PostalCode = "12345"
},
Metadata = new Dictionary<string, string>()
};
_subscriberService.GetCustomer(organization).Returns(customer);
const string token = "TOKEN";
var setupIntent = new SetupIntent
{
Id = "seti_123",
PaymentMethod =
new PaymentMethod
{
Type = "us_bank_account",
UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" }
},
NextAction = new SetupIntentNextAction
{
VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits()
},
Status = "requires_action"
};
_stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options =>
options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]);
var result = await _command.Run(organization,
new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = token }, new BillingAddress
{
Country = "US",
PostalCode = "12345"
});
Assert.True(result.IsT0);
var maskedPaymentMethod = result.AsT0;
Assert.True(maskedPaymentMethod.IsT0);
var maskedBankAccount = maskedPaymentMethod.AsT0;
Assert.Equal("Chase", maskedBankAccount.BankName);
Assert.Equal("9999", maskedBankAccount.Last4);
Assert.False(maskedBankAccount.Verified);
await _subscriberService.Received(1).CreateStripeCustomer(organization);
await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id);
}
[Fact]
public async Task Run_BankAccount_StripeToPayPal_MakesCorrectInvocations_ReturnsMaskedBankAccount()
{
var organization = new Organization
{
Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
};
var customer = new Customer var customer = new Customer
{ {
Address = new Address Address = new Address
@ -170,7 +233,8 @@ public class UpdatePaymentMethodCommandTests
{ {
var organization = new Organization var organization = new Organization
{ {
Id = Guid.NewGuid() Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
}; };
var customer = new Customer var customer = new Customer
@ -227,7 +291,8 @@ public class UpdatePaymentMethodCommandTests
{ {
var organization = new Organization var organization = new Organization
{ {
Id = Guid.NewGuid() Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
}; };
var customer = new Customer var customer = new Customer
@ -282,7 +347,8 @@ public class UpdatePaymentMethodCommandTests
{ {
var organization = new Organization var organization = new Organization
{ {
Id = Guid.NewGuid() Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
}; };
var customer = new Customer var customer = new Customer
@ -343,7 +409,8 @@ public class UpdatePaymentMethodCommandTests
{ {
var organization = new Organization var organization = new Organization
{ {
Id = Guid.NewGuid() Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
}; };
var customer = new Customer var customer = new Customer

View File

@ -25,6 +25,27 @@ public class MaskedPaymentMethodTests
Assert.Equivalent(input.AsT0, output.AsT0); Assert.Equivalent(input.AsT0, output.AsT0);
} }
[Fact]
public void Write_Read_BankAccount_WithOptions_Succeeds()
{
MaskedPaymentMethod input = new MaskedBankAccount
{
BankName = "Chase",
Last4 = "9999",
Verified = true
};
var jsonSerializerOptions = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase };
var json = JsonSerializer.Serialize(input, jsonSerializerOptions);
var output = JsonSerializer.Deserialize<MaskedPaymentMethod>(json, jsonSerializerOptions);
Assert.NotNull(output);
Assert.True(output.IsT0);
Assert.Equivalent(input.AsT0, output.AsT0);
}
[Fact] [Fact]
public void Write_Read_Card_Succeeds() public void Write_Read_Card_Succeeds()
{ {

View File

@ -8,6 +8,7 @@ using Bit.Core.Test.Billing.Extensions;
using Braintree; using Braintree;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using NSubstitute; using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Stripe; using Stripe;
using Xunit; using Xunit;
using Customer = Stripe.Customer; using Customer = Stripe.Customer;
@ -35,6 +36,23 @@ public class GetPaymentMethodQueryTests
_subscriberService); _subscriberService);
} }
[Fact]
public async Task Run_NoCustomer_ReturnsNull()
{
var organization = new Organization
{
Id = Guid.NewGuid()
};
_subscriberService.GetCustomer(organization,
Arg.Is<CustomerGetOptions>(options =>
options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).ReturnsNull();
var maskedPaymentMethod = await _query.Run(organization);
Assert.Null(maskedPaymentMethod);
}
[Fact] [Fact]
public async Task Run_NoPaymentMethod_ReturnsNull() public async Task Run_NoPaymentMethod_ReturnsNull()
{ {