1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-18 16:11:28 -05:00

[PM-23687] Support free organizations on Payment Details page (#6084)

* Resolve JSON serialization bug in OneOf converters and organize pricing models

* Support free organizations for payment method and billing address flows

* Run dotnet format
This commit is contained in:
Alex Morask
2025-07-14 12:39:49 -05:00
committed by GitHub
parent 0e4e060f22
commit d914ab8a98
30 changed files with 575 additions and 316 deletions

View File

@ -1,4 +1,5 @@
using Bit.Core.Billing.Constants;
using Bit.Core.Exceptions;
using Microsoft.Extensions.Logging;
using Stripe;
@ -6,11 +7,17 @@ namespace Bit.Core.Billing.Commands;
using static StripeConstants;
public abstract class BillingCommand<T>(
public abstract class BaseBillingCommand<T>(
ILogger<T> logger)
{
protected string CommandName => GetType().Name;
/// <summary>
/// Override this property to set a client-facing conflict response in the case a <see cref="ConflictException"/> is thrown
/// during the command's execution.
/// </summary>
protected virtual Conflict? DefaultConflict => null;
/// <summary>
/// Executes the provided function within a predefined execution context, handling any exceptions that occur during the process.
/// </summary>
@ -29,23 +36,35 @@ public abstract class BillingCommand<T>(
return stripeException.StripeError.Code switch
{
ErrorCodes.CustomerTaxLocationInvalid =>
new BadRequest("Your location wasn't recognized. Please ensure your country and postal code are valid and try again."),
new BadRequest(
"Your location wasn't recognized. Please ensure your country and postal code are valid and try again."),
ErrorCodes.PaymentMethodMicroDepositVerificationAttemptsExceeded =>
new BadRequest("You have exceeded the number of allowed verification attempts. Please contact support for assistance."),
new BadRequest(
"You have exceeded the number of allowed verification attempts. Please contact support for assistance."),
ErrorCodes.PaymentMethodMicroDepositVerificationDescriptorCodeMismatch =>
new BadRequest("The verification code you provided does not match the one sent to your bank account. Please try again."),
new BadRequest(
"The verification code you provided does not match the one sent to your bank account. Please try again."),
ErrorCodes.PaymentMethodMicroDepositVerificationTimeout =>
new BadRequest("Your bank account was not verified within the required time period. Please contact support for assistance."),
new BadRequest(
"Your bank account was not verified within the required time period. Please contact support for assistance."),
ErrorCodes.TaxIdInvalid =>
new BadRequest("The tax ID number you provided was invalid. Please try again or contact support for assistance."),
new BadRequest(
"The tax ID number you provided was invalid. Please try again or contact support for assistance."),
_ => new Unhandled(stripeException)
};
}
catch (ConflictException conflictException)
{
logger.LogError("{Command}: {Message}", CommandName, conflictException.Message);
return DefaultConflict != null ?
DefaultConflict :
new Unhandled(conflictException);
}
catch (StripeException stripeException)
{
logger.LogError(stripeException,

View File

@ -60,6 +60,7 @@ public static class StripeConstants
public const string InvoiceApproved = "invoice_approved";
public const string OrganizationId = "organizationId";
public const string ProviderId = "providerId";
public const string Region = "region";
public const string RetiredBraintreeCustomerId = "btCustomerId_old";
public const string UserId = "userId";
}

View File

@ -1,5 +1,4 @@
#nullable enable
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Payment.Clients;
@ -21,8 +20,10 @@ public interface ICreateBitPayInvoiceForCreditCommand
public class CreateBitPayInvoiceForCreditCommand(
IBitPayClient bitPayClient,
GlobalSettings globalSettings,
ILogger<CreateBitPayInvoiceForCreditCommand> logger) : BillingCommand<CreateBitPayInvoiceForCreditCommand>(logger), ICreateBitPayInvoiceForCreditCommand
ILogger<CreateBitPayInvoiceForCreditCommand> logger) : BaseBillingCommand<CreateBitPayInvoiceForCreditCommand>(logger), ICreateBitPayInvoiceForCreditCommand
{
protected override Conflict DefaultConflict => new("We had a problem applying your account credit. Please contact support for assistance.");
public Task<BillingCommandResult<string>> Run(
ISubscriber subscriber,
decimal amount,

View File

@ -1,8 +1,8 @@
#nullable enable
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
@ -19,14 +19,26 @@ public interface IUpdateBillingAddressCommand
public class UpdateBillingAddressCommand(
ILogger<UpdateBillingAddressCommand> logger,
IStripeAdapter stripeAdapter) : BillingCommand<UpdateBillingAddressCommand>(logger), IUpdateBillingAddressCommand
ISubscriberService subscriberService,
IStripeAdapter stripeAdapter) : BaseBillingCommand<UpdateBillingAddressCommand>(logger), IUpdateBillingAddressCommand
{
protected override Conflict DefaultConflict =>
new("We had a problem updating your billing address. Please contact support for assistance.");
public Task<BillingCommandResult<BillingAddress>> Run(
ISubscriber subscriber,
BillingAddress billingAddress) => HandleAsync(() => subscriber.GetProductUsageType() switch
BillingAddress billingAddress) => HandleAsync(async () =>
{
ProductUsageType.Personal => UpdatePersonalBillingAddressAsync(subscriber, billingAddress),
ProductUsageType.Business => UpdateBusinessBillingAddressAsync(subscriber, billingAddress)
if (string.IsNullOrEmpty(subscriber.GatewayCustomerId))
{
await subscriberService.CreateStripeCustomer(subscriber);
}
return subscriber.GetProductUsageType() switch
{
ProductUsageType.Personal => await UpdatePersonalBillingAddressAsync(subscriber, billingAddress),
ProductUsageType.Business => await UpdateBusinessBillingAddressAsync(subscriber, billingAddress)
};
});
private async Task<BillingCommandResult<BillingAddress>> UpdatePersonalBillingAddressAsync(

View File

@ -1,5 +1,4 @@
#nullable enable
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Payment.Models;
@ -29,16 +28,22 @@ public class UpdatePaymentMethodCommand(
ILogger<UpdatePaymentMethodCommand> logger,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService) : BillingCommand<UpdatePaymentMethodCommand>(logger), IUpdatePaymentMethodCommand
ISubscriberService subscriberService) : BaseBillingCommand<UpdatePaymentMethodCommand>(logger), IUpdatePaymentMethodCommand
{
private readonly ILogger<UpdatePaymentMethodCommand> _logger = logger;
private static readonly Conflict _conflict = new("We had a problem updating your payment method. Please contact support for assistance.");
protected override Conflict DefaultConflict
=> new("We had a problem updating your payment method. Please contact support for assistance.");
public Task<BillingCommandResult<MaskedPaymentMethod>> Run(
ISubscriber subscriber,
TokenizedPaymentMethod paymentMethod,
BillingAddress? billingAddress) => HandleAsync(async () =>
{
if (string.IsNullOrEmpty(subscriber.GatewayCustomerId))
{
await subscriberService.CreateStripeCustomer(subscriber);
}
var customer = await subscriberService.GetCustomer(subscriber);
var result = paymentMethod.Type switch
@ -80,10 +85,10 @@ public class UpdatePaymentMethodCommand(
{
case 0:
_logger.LogError("{Command}: Could not find setup intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id);
return _conflict;
return DefaultConflict;
case > 1:
_logger.LogError("{Command}: Found more than one set up intent for subscriber's ({SubscriberID}) bank account", CommandName, subscriber.Id);
return _conflict;
return DefaultConflict;
}
var setupIntent = setupIntents.First();

View File

@ -1,5 +1,4 @@
#nullable enable
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Entities;
@ -19,12 +18,12 @@ public interface IVerifyBankAccountCommand
public class VerifyBankAccountCommand(
ILogger<VerifyBankAccountCommand> logger,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter) : BillingCommand<VerifyBankAccountCommand>(logger), IVerifyBankAccountCommand
IStripeAdapter stripeAdapter) : BaseBillingCommand<VerifyBankAccountCommand>(logger), IVerifyBankAccountCommand
{
private readonly ILogger<VerifyBankAccountCommand> _logger = logger;
private static readonly Conflict _conflict =
new("We had a problem verifying your bank account. Please contact support for assistance.");
protected override Conflict DefaultConflict
=> new("We had a problem verifying your bank account. Please contact support for assistance.");
public Task<BillingCommandResult<MaskedPaymentMethod>> Run(
ISubscriber subscriber,
@ -37,7 +36,7 @@ public class VerifyBankAccountCommand(
_logger.LogError(
"{Command}: Could not find setup intent to verify subscriber's ({SubscriberID}) bank account",
CommandName, subscriber.Id);
return _conflict;
return DefaultConflict;
}
await stripeAdapter.SetupIntentVerifyMicroDeposit(setupIntentId,

View File

@ -1,5 +1,4 @@
#nullable enable
using Stripe;
using Stripe;
namespace Bit.Core.Billing.Payment.Models;

View File

@ -1,7 +1,5 @@
#nullable enable
using System.Text.Json;
using System.Text.Json;
using System.Text.Json.Serialization;
using Bit.Core.Billing.Pricing.JSON;
using Braintree;
using OneOf;
using Stripe;
@ -83,32 +81,28 @@ public class MaskedPaymentMethod(OneOf<MaskedBankAccount, MaskedCard, MaskedPayP
public static MaskedPaymentMethod From(PayPalAccount payPalAccount) => new MaskedPayPalAccount { Email = payPalAccount.Email };
}
public class MaskedPaymentMethodJsonConverter : TypeReadingJsonConverter<MaskedPaymentMethod>
public class MaskedPaymentMethodJsonConverter : JsonConverter<MaskedPaymentMethod>
{
protected override string TypePropertyName => nameof(MaskedBankAccount.Type).ToLower();
private const string _typePropertyName = nameof(MaskedBankAccount.Type);
public override MaskedPaymentMethod? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
public override MaskedPaymentMethod Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var type = ReadType(reader);
var element = JsonElement.ParseValue(ref reader);
if (!element.TryGetProperty(options.PropertyNamingPolicy?.ConvertName(_typePropertyName) ?? _typePropertyName, out var typeProperty))
{
throw new JsonException(
$"Failed to deserialize {nameof(MaskedPaymentMethod)}: missing '{_typePropertyName}' property");
}
var type = typeProperty.GetString();
return type switch
{
"bankAccount" => JsonSerializer.Deserialize<MaskedBankAccount>(ref reader, options) switch
{
null => null,
var bankAccount => new MaskedPaymentMethod(bankAccount)
},
"card" => JsonSerializer.Deserialize<MaskedCard>(ref reader, options) switch
{
null => null,
var card => new MaskedPaymentMethod(card)
},
"payPal" => JsonSerializer.Deserialize<MaskedPayPalAccount>(ref reader, options) switch
{
null => null,
var payPal => new MaskedPaymentMethod(payPal)
},
_ => Skip(ref reader)
"bankAccount" => element.Deserialize<MaskedBankAccount>(options)!,
"card" => element.Deserialize<MaskedCard>(options)!,
"payPal" => element.Deserialize<MaskedPayPalAccount>(options)!,
_ => throw new JsonException($"Failed to deserialize {nameof(MaskedPaymentMethod)}: invalid '{_typePropertyName}' value - '{type}'")
};
}

View File

@ -1,5 +1,4 @@
#nullable enable
namespace Bit.Core.Billing.Payment.Models;
namespace Bit.Core.Billing.Payment.Models;
public record TokenizedPaymentMethod
{

View File

@ -1,5 +1,4 @@
#nullable enable
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;

View File

@ -1,5 +1,4 @@
#nullable enable
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
namespace Bit.Core.Billing.Payment.Queries;

View File

@ -1,5 +1,4 @@
#nullable enable
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Payment.Models;
@ -29,6 +28,11 @@ public class GetPaymentMethodQuery(
var customer = await subscriberService.GetCustomer(subscriber,
new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] });
if (customer == null)
{
return null;
}
if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId))
{
var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId);

View File

@ -1,35 +0,0 @@
using System.Text.Json;
using Bit.Core.Billing.Pricing.Models;
namespace Bit.Core.Billing.Pricing.JSON;
#nullable enable
public class FreeOrScalableDTOJsonConverter : TypeReadingJsonConverter<FreeOrScalableDTO>
{
public override FreeOrScalableDTO? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var type = ReadType(reader);
return type switch
{
"free" => JsonSerializer.Deserialize<FreeDTO>(ref reader, options) switch
{
null => null,
var free => new FreeOrScalableDTO(free)
},
"scalable" => JsonSerializer.Deserialize<ScalableDTO>(ref reader, options) switch
{
null => null,
var scalable => new FreeOrScalableDTO(scalable)
},
_ => null
};
}
public override void Write(Utf8JsonWriter writer, FreeOrScalableDTO value, JsonSerializerOptions options)
=> value.Switch(
free => JsonSerializer.Serialize(writer, free, options),
scalable => JsonSerializer.Serialize(writer, scalable, options)
);
}

View File

@ -1,40 +0,0 @@
using System.Text.Json;
using Bit.Core.Billing.Pricing.Models;
namespace Bit.Core.Billing.Pricing.JSON;
#nullable enable
internal class PurchasableDTOJsonConverter : TypeReadingJsonConverter<PurchasableDTO>
{
public override PurchasableDTO? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var type = ReadType(reader);
return type switch
{
"free" => JsonSerializer.Deserialize<FreeDTO>(ref reader, options) switch
{
null => null,
var free => new PurchasableDTO(free)
},
"packaged" => JsonSerializer.Deserialize<PackagedDTO>(ref reader, options) switch
{
null => null,
var packaged => new PurchasableDTO(packaged)
},
"scalable" => JsonSerializer.Deserialize<ScalableDTO>(ref reader, options) switch
{
null => null,
var scalable => new PurchasableDTO(scalable)
},
_ => null
};
}
public override void Write(Utf8JsonWriter writer, PurchasableDTO value, JsonSerializerOptions options)
=> value.Switch(
free => JsonSerializer.Serialize(writer, free, options),
packaged => JsonSerializer.Serialize(writer, packaged, options),
scalable => JsonSerializer.Serialize(writer, scalable, options)
);
}

View File

@ -1,36 +0,0 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using Bit.Core.Billing.Pricing.Models;
namespace Bit.Core.Billing.Pricing.JSON;
#nullable enable
public abstract class TypeReadingJsonConverter<T> : JsonConverter<T> where T : class
{
protected virtual string TypePropertyName => nameof(ScalableDTO.Type).ToLower();
protected string? ReadType(Utf8JsonReader reader)
{
while (reader.Read())
{
if (reader.CurrentDepth != 1 ||
reader.TokenType != JsonTokenType.PropertyName ||
reader.GetString()?.ToLower() != TypePropertyName)
{
continue;
}
reader.Read();
return reader.GetString();
}
return null;
}
protected T? Skip(ref Utf8JsonReader reader)
{
reader.Skip();
return null;
}
}

View File

@ -0,0 +1,7 @@
namespace Bit.Core.Billing.Pricing.Models;
public class Feature
{
public required string Name { get; set; }
public required string LookupKey { get; set; }
}

View File

@ -1,9 +0,0 @@
namespace Bit.Core.Billing.Pricing.Models;
#nullable enable
public class FeatureDTO
{
public string Name { get; set; } = null!;
public string LookupKey { get; set; } = null!;
}

View File

@ -0,0 +1,25 @@
namespace Bit.Core.Billing.Pricing.Models;
public class Plan
{
public required string LookupKey { get; set; }
public required string Name { get; set; }
public required string Tier { get; set; }
public string? Cadence { get; set; }
public int? LegacyYear { get; set; }
public bool Available { get; set; }
public required Feature[] Features { get; set; }
public required Purchasable Seats { get; set; }
public Scalable? ManagedSeats { get; set; }
public Scalable? Storage { get; set; }
public SecretsManagerPurchasables? SecretsManager { get; set; }
public int? TrialPeriodDays { get; set; }
public required string[] CanUpgradeTo { get; set; }
public required Dictionary<string, string> AdditionalData { get; set; }
}
public class SecretsManagerPurchasables
{
public required FreeOrScalable Seats { get; set; }
public required FreeOrScalable ServiceAccounts { get; set; }
}

View File

@ -1,27 +0,0 @@
namespace Bit.Core.Billing.Pricing.Models;
#nullable enable
public class PlanDTO
{
public string LookupKey { get; set; } = null!;
public string Name { get; set; } = null!;
public string Tier { get; set; } = null!;
public string? Cadence { get; set; }
public int? LegacyYear { get; set; }
public bool Available { get; set; }
public FeatureDTO[] Features { get; set; } = null!;
public PurchasableDTO Seats { get; set; } = null!;
public ScalableDTO? ManagedSeats { get; set; }
public ScalableDTO? Storage { get; set; }
public SecretsManagerPurchasablesDTO? SecretsManager { get; set; }
public int? TrialPeriodDays { get; set; }
public string[] CanUpgradeTo { get; set; } = null!;
public Dictionary<string, string> AdditionalData { get; set; } = null!;
}
public class SecretsManagerPurchasablesDTO
{
public FreeOrScalableDTO Seats { get; set; } = null!;
public FreeOrScalableDTO ServiceAccounts { get; set; } = null!;
}

View File

@ -0,0 +1,135 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using OneOf;
namespace Bit.Core.Billing.Pricing.Models;
[JsonConverter(typeof(PurchasableJsonConverter))]
public class Purchasable(OneOf<Free, Packaged, Scalable> input) : OneOfBase<Free, Packaged, Scalable>(input)
{
public static implicit operator Purchasable(Free free) => new(free);
public static implicit operator Purchasable(Packaged packaged) => new(packaged);
public static implicit operator Purchasable(Scalable scalable) => new(scalable);
public T? FromFree<T>(Func<Free, T> select, Func<Purchasable, T>? fallback = null) =>
IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default;
public T? FromPackaged<T>(Func<Packaged, T> select, Func<Purchasable, T>? fallback = null) =>
IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default;
public T? FromScalable<T>(Func<Scalable, T> select, Func<Purchasable, T>? fallback = null) =>
IsT2 ? select(AsT2) : fallback != null ? fallback(this) : default;
public bool IsFree => IsT0;
public bool IsPackaged => IsT1;
public bool IsScalable => IsT2;
}
internal class PurchasableJsonConverter : JsonConverter<Purchasable>
{
private static readonly string _typePropertyName = nameof(Free.Type).ToLower();
public override Purchasable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var element = JsonElement.ParseValue(ref reader);
if (!element.TryGetProperty(options.PropertyNamingPolicy?.ConvertName(_typePropertyName) ?? _typePropertyName, out var typeProperty))
{
throw new JsonException(
$"Failed to deserialize {nameof(Purchasable)}: missing '{_typePropertyName}' property");
}
var type = typeProperty.GetString();
return type switch
{
"free" => element.Deserialize<Free>(options)!,
"packaged" => element.Deserialize<Packaged>(options)!,
"scalable" => element.Deserialize<Scalable>(options)!,
_ => throw new JsonException($"Failed to deserialize {nameof(Purchasable)}: invalid '{_typePropertyName}' value - '{type}'"),
};
}
public override void Write(Utf8JsonWriter writer, Purchasable value, JsonSerializerOptions options)
=> value.Switch(
free => JsonSerializer.Serialize(writer, free, options),
packaged => JsonSerializer.Serialize(writer, packaged, options),
scalable => JsonSerializer.Serialize(writer, scalable, options)
);
}
[JsonConverter(typeof(FreeOrScalableJsonConverter))]
public class FreeOrScalable(OneOf<Free, Scalable> input) : OneOfBase<Free, Scalable>(input)
{
public static implicit operator FreeOrScalable(Free free) => new(free);
public static implicit operator FreeOrScalable(Scalable scalable) => new(scalable);
public T? FromFree<T>(Func<Free, T> select, Func<FreeOrScalable, T>? fallback = null) =>
IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default;
public T? FromScalable<T>(Func<Scalable, T> select, Func<FreeOrScalable, T>? fallback = null) =>
IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default;
public bool IsFree => IsT0;
public bool IsScalable => IsT1;
}
public class FreeOrScalableJsonConverter : JsonConverter<FreeOrScalable>
{
private static readonly string _typePropertyName = nameof(Free.Type).ToLower();
public override FreeOrScalable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var element = JsonElement.ParseValue(ref reader);
if (!element.TryGetProperty(options.PropertyNamingPolicy?.ConvertName(_typePropertyName) ?? _typePropertyName, out var typeProperty))
{
throw new JsonException(
$"Failed to deserialize {nameof(FreeOrScalable)}: missing '{_typePropertyName}' property");
}
var type = typeProperty.GetString();
return type switch
{
"free" => element.Deserialize<Free>(options)!,
"scalable" => element.Deserialize<Scalable>(options)!,
_ => throw new JsonException($"Failed to deserialize {nameof(FreeOrScalable)}: invalid '{_typePropertyName}' value - '{type}'"),
};
}
public override void Write(Utf8JsonWriter writer, FreeOrScalable value, JsonSerializerOptions options)
=> value.Switch(
free => JsonSerializer.Serialize(writer, free, options),
scalable => JsonSerializer.Serialize(writer, scalable, options)
);
}
public class Free
{
public int Quantity { get; set; }
public string Type => "free";
}
public class Packaged
{
public int Quantity { get; set; }
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
public AdditionalSeats? Additional { get; set; }
public string Type => "packaged";
public class AdditionalSeats
{
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
}
}
public class Scalable
{
public int Provided { get; set; }
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
public string Type => "scalable";
}

View File

@ -1,73 +0,0 @@
using System.Text.Json.Serialization;
using Bit.Core.Billing.Pricing.JSON;
using OneOf;
namespace Bit.Core.Billing.Pricing.Models;
#nullable enable
[JsonConverter(typeof(PurchasableDTOJsonConverter))]
public class PurchasableDTO(OneOf<FreeDTO, PackagedDTO, ScalableDTO> input) : OneOfBase<FreeDTO, PackagedDTO, ScalableDTO>(input)
{
public static implicit operator PurchasableDTO(FreeDTO free) => new(free);
public static implicit operator PurchasableDTO(PackagedDTO packaged) => new(packaged);
public static implicit operator PurchasableDTO(ScalableDTO scalable) => new(scalable);
public T? FromFree<T>(Func<FreeDTO, T> select, Func<PurchasableDTO, T>? fallback = null) =>
IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default;
public T? FromPackaged<T>(Func<PackagedDTO, T> select, Func<PurchasableDTO, T>? fallback = null) =>
IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default;
public T? FromScalable<T>(Func<ScalableDTO, T> select, Func<PurchasableDTO, T>? fallback = null) =>
IsT2 ? select(AsT2) : fallback != null ? fallback(this) : default;
public bool IsFree => IsT0;
public bool IsPackaged => IsT1;
public bool IsScalable => IsT2;
}
[JsonConverter(typeof(FreeOrScalableDTOJsonConverter))]
public class FreeOrScalableDTO(OneOf<FreeDTO, ScalableDTO> input) : OneOfBase<FreeDTO, ScalableDTO>(input)
{
public static implicit operator FreeOrScalableDTO(FreeDTO freeDTO) => new(freeDTO);
public static implicit operator FreeOrScalableDTO(ScalableDTO scalableDTO) => new(scalableDTO);
public T? FromFree<T>(Func<FreeDTO, T> select, Func<FreeOrScalableDTO, T>? fallback = null) =>
IsT0 ? select(AsT0) : fallback != null ? fallback(this) : default;
public T? FromScalable<T>(Func<ScalableDTO, T> select, Func<FreeOrScalableDTO, T>? fallback = null) =>
IsT1 ? select(AsT1) : fallback != null ? fallback(this) : default;
public bool IsFree => IsT0;
public bool IsScalable => IsT1;
}
public class FreeDTO
{
public int Quantity { get; set; }
public string Type => "free";
}
public class PackagedDTO
{
public int Quantity { get; set; }
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
public AdditionalSeats? Additional { get; set; }
public string Type => "packaged";
public class AdditionalSeats
{
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
}
}
public class ScalableDTO
{
public int Provided { get; set; }
public string StripePriceId { get; set; } = null!;
public decimal Price { get; set; }
public string Type => "scalable";
}

View File

@ -1,14 +1,12 @@
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing.Models;
using Bit.Core.Models.StaticStore;
#nullable enable
using Plan = Bit.Core.Billing.Pricing.Models.Plan;
namespace Bit.Core.Billing.Pricing;
public record PlanAdapter : Plan
public record PlanAdapter : Core.Models.StaticStore.Plan
{
public PlanAdapter(PlanDTO plan)
public PlanAdapter(Plan plan)
{
Type = ToPlanType(plan.LookupKey);
ProductTier = ToProductTierType(Type);
@ -88,7 +86,7 @@ public record PlanAdapter : Plan
_ => throw new BillingException() // TODO: Flesh out
};
private static PasswordManagerPlanFeatures ToPasswordManagerPlanFeatures(PlanDTO plan)
private static PasswordManagerPlanFeatures ToPasswordManagerPlanFeatures(Plan plan)
{
var stripePlanId = GetStripePlanId(plan.Seats);
var stripeSeatPlanId = GetStripeSeatPlanId(plan.Seats);
@ -128,7 +126,7 @@ public record PlanAdapter : Plan
};
}
private static SecretsManagerPlanFeatures ToSecretsManagerPlanFeatures(PlanDTO plan)
private static SecretsManagerPlanFeatures ToSecretsManagerPlanFeatures(Plan plan)
{
var seats = plan.SecretsManager!.Seats;
var serviceAccounts = plan.SecretsManager.ServiceAccounts;
@ -165,62 +163,62 @@ public record PlanAdapter : Plan
};
}
private static decimal? GetAdditionalPricePerServiceAccount(FreeOrScalableDTO freeOrScalable)
private static decimal? GetAdditionalPricePerServiceAccount(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromScalable(x => x.Price);
private static decimal GetBasePrice(PurchasableDTO purchasable)
private static decimal GetBasePrice(Purchasable purchasable)
=> purchasable.FromPackaged(x => x.Price);
private static int GetBaseSeats(FreeOrScalableDTO freeOrScalable)
private static int GetBaseSeats(FreeOrScalable freeOrScalable)
=> freeOrScalable.Match(
free => free.Quantity,
scalable => scalable.Provided);
private static int GetBaseSeats(PurchasableDTO purchasable)
private static int GetBaseSeats(Purchasable purchasable)
=> purchasable.Match(
free => free.Quantity,
packaged => packaged.Quantity,
scalable => scalable.Provided);
private static short GetBaseServiceAccount(FreeOrScalableDTO freeOrScalable)
private static short GetBaseServiceAccount(FreeOrScalable freeOrScalable)
=> freeOrScalable.Match(
free => (short)free.Quantity,
scalable => (short)scalable.Provided);
private static short? GetMaxSeats(PurchasableDTO purchasable)
private static short? GetMaxSeats(Purchasable purchasable)
=> purchasable.Match<short?>(
free => (short)free.Quantity,
packaged => (short)packaged.Quantity,
_ => null);
private static short? GetMaxSeats(FreeOrScalableDTO freeOrScalable)
private static short? GetMaxSeats(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromFree(x => (short)x.Quantity);
private static short? GetMaxServiceAccounts(FreeOrScalableDTO freeOrScalable)
private static short? GetMaxServiceAccounts(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromFree(x => (short)x.Quantity);
private static decimal GetSeatPrice(PurchasableDTO purchasable)
private static decimal GetSeatPrice(Purchasable purchasable)
=> purchasable.Match(
_ => 0,
packaged => packaged.Additional?.Price ?? 0,
scalable => scalable.Price);
private static decimal GetSeatPrice(FreeOrScalableDTO freeOrScalable)
private static decimal GetSeatPrice(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromScalable(x => x.Price);
private static string? GetStripePlanId(PurchasableDTO purchasable)
private static string? GetStripePlanId(Purchasable purchasable)
=> purchasable.FromPackaged(x => x.StripePriceId);
private static string? GetStripeSeatPlanId(PurchasableDTO purchasable)
private static string? GetStripeSeatPlanId(Purchasable purchasable)
=> purchasable.Match(
_ => null,
packaged => packaged.Additional?.StripePriceId,
scalable => scalable.StripePriceId);
private static string? GetStripeSeatPlanId(FreeOrScalableDTO freeOrScalable)
private static string? GetStripeSeatPlanId(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromScalable(x => x.StripePriceId);
private static string? GetStripeServiceAccountPlanId(FreeOrScalableDTO freeOrScalable)
private static string? GetStripeServiceAccountPlanId(FreeOrScalable freeOrScalable)
=> freeOrScalable.FromScalable(x => x.StripePriceId);
#endregion

View File

@ -1,7 +1,6 @@
using System.Net;
using System.Net.Http.Json;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing.Models;
using Bit.Core.Exceptions;
using Bit.Core.Services;
using Bit.Core.Settings;
@ -45,7 +44,7 @@ public class PricingClient(
if (response.IsSuccessStatusCode)
{
var plan = await response.Content.ReadFromJsonAsync<PlanDTO>();
var plan = await response.Content.ReadFromJsonAsync<Models.Plan>();
if (plan == null)
{
throw new BillingException(message: "Deserialization of Pricing Service response resulted in null");
@ -93,7 +92,7 @@ public class PricingClient(
if (response.IsSuccessStatusCode)
{
var plans = await response.Content.ReadFromJsonAsync<List<PlanDTO>>();
var plans = await response.Content.ReadFromJsonAsync<List<Models.Plan>>();
if (plans == null)
{
throw new BillingException(message: "Deserialization of Pricing Service response resulted in null");

View File

@ -36,6 +36,9 @@ public interface ISubscriberService
ISubscriber subscriber,
string paymentMethodNonce);
Task<Customer> CreateStripeCustomer(
ISubscriber subscriber);
/// <summary>
/// Retrieves a Stripe <see cref="Customer"/> using the <paramref name="subscriber"/>'s <see cref="ISubscriber.GatewayCustomerId"/> property.
/// </summary>

View File

@ -3,6 +3,7 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
@ -13,6 +14,7 @@ using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
@ -27,14 +29,19 @@ using Subscription = Stripe.Subscription;
namespace Bit.Core.Billing.Services.Implementations;
using static StripeConstants;
public class SubscriberService(
IBraintreeGateway braintreeGateway,
IFeatureService featureService,
IGlobalSettings globalSettings,
ILogger<SubscriberService> logger,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ITaxService taxService) : ISubscriberService
ITaxService taxService,
IUserRepository userRepository) : ISubscriberService
{
public async Task CancelSubscription(
ISubscriber subscriber,
@ -146,6 +153,110 @@ public class SubscriberService(
throw new BillingException();
}
#nullable enable
public async Task<Customer> CreateStripeCustomer(ISubscriber subscriber)
{
if (!string.IsNullOrEmpty(subscriber.GatewayCustomerId))
{
throw new ConflictException("Subscriber already has a linked Stripe Customer");
}
var options = subscriber switch
{
Organization organization => new CustomerCreateOptions
{
Description = organization.DisplayBusinessName(),
Email = organization.BillingEmail,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = organization.SubscriberType(),
Value = Max30Characters(organization.DisplayName())
}
]
},
Metadata = new Dictionary<string, string>
{
[MetadataKeys.OrganizationId] = organization.Id.ToString(),
[MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion
}
},
Provider provider => new CustomerCreateOptions
{
Description = provider.DisplayBusinessName(),
Email = provider.BillingEmail,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = provider.SubscriberType(),
Value = Max30Characters(provider.DisplayName())
}
]
},
Metadata = new Dictionary<string, string>
{
[MetadataKeys.ProviderId] = provider.Id.ToString(),
[MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion
}
},
User user => new CustomerCreateOptions
{
Description = user.Name,
Email = user.Email,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = user.SubscriberType(),
Value = Max30Characters(user.SubscriberName())
}
]
},
Metadata = new Dictionary<string, string>
{
[MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion,
[MetadataKeys.UserId] = user.Id.ToString()
}
},
_ => throw new ArgumentOutOfRangeException(nameof(subscriber))
};
var customer = await stripeAdapter.CustomerCreateAsync(options);
switch (subscriber)
{
case Organization organization:
organization.Gateway = GatewayType.Stripe;
organization.GatewayCustomerId = customer.Id;
await organizationRepository.ReplaceAsync(organization);
break;
case Provider provider:
provider.Gateway = GatewayType.Stripe;
provider.GatewayCustomerId = customer.Id;
await providerRepository.ReplaceAsync(provider);
break;
case User user:
user.Gateway = GatewayType.Stripe;
user.GatewayCustomerId = customer.Id;
await userRepository.ReplaceAsync(user);
break;
}
return customer;
string? Max30Characters(string? input)
=> input?.Length <= 30 ? input : input?[..30];
}
#nullable disable
public async Task<Customer> GetCustomer(
ISubscriber subscriber,
CustomerGetOptions customerGetOptions = null)

View File

@ -1,5 +1,4 @@
#nullable enable
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Commands;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
@ -20,8 +19,11 @@ public class PreviewTaxAmountCommand(
ILogger<PreviewTaxAmountCommand> logger,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter,
ITaxService taxService) : BillingCommand<PreviewTaxAmountCommand>(logger), IPreviewTaxAmountCommand
ITaxService taxService) : BaseBillingCommand<PreviewTaxAmountCommand>(logger), IPreviewTaxAmountCommand
{
protected override Conflict DefaultConflict
=> new("We had a problem calculating your tax obligation. Please contact support for assistance.");
public Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters)
=> HandleAsync<decimal>(async () =>
{

View File

@ -3,6 +3,7 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Payment.Commands;
using Bit.Core.Billing.Payment.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Services;
using Bit.Core.Test.Billing.Extensions;
using Microsoft.Extensions.Logging;
@ -16,14 +17,15 @@ using static StripeConstants;
public class UpdateBillingAddressCommandTests
{
private readonly IStripeAdapter _stripeAdapter;
private readonly ISubscriberService _subscriberService = Substitute.For<ISubscriberService>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly UpdateBillingAddressCommand _command;
public UpdateBillingAddressCommandTests()
{
_stripeAdapter = Substitute.For<IStripeAdapter>();
_command = new UpdateBillingAddressCommand(
Substitute.For<ILogger<UpdateBillingAddressCommand>>(),
_subscriberService,
_stripeAdapter);
}
@ -86,6 +88,66 @@ public class UpdateBillingAddressCommandTests
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
}
[Fact]
public async Task Run_PersonalOrganization_NoCurrentCustomer_MakesCorrectInvocations_ReturnsBillingAddress()
{
var organization = new Organization
{
PlanType = PlanType.FamiliesAnnually,
GatewaySubscriptionId = "sub_123"
};
var input = new BillingAddress
{
Country = "US",
PostalCode = "12345",
Line1 = "123 Main St.",
Line2 = "Suite 100",
City = "New York",
State = "NY"
};
var customer = new Customer
{
Address = new Address
{
Country = "US",
PostalCode = "12345",
Line1 = "123 Main St.",
Line2 = "Suite 100",
City = "New York",
State = "NY"
},
Subscriptions = new StripeList<Subscription>
{
Data =
[
new Subscription
{
Id = organization.GatewaySubscriptionId,
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
}
]
}
};
_stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
options.Address.Matches(input) &&
options.HasExpansions("subscriptions")
)).Returns(customer);
var result = await _command.Run(organization, input);
Assert.True(result.IsT0);
var output = result.AsT0;
Assert.Equivalent(input, output);
await _subscriberService.Received(1).CreateStripeCustomer(organization);
await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
}
[Fact]
public async Task Run_BusinessOrganization_MakesCorrectInvocations_ReturnsBillingAddress()
{

View File

@ -45,7 +45,8 @@ public class UpdatePaymentMethodCommandTests
{
var organization = new Organization
{
Id = Guid.NewGuid()
Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
};
var customer = new Customer
@ -100,13 +101,75 @@ public class UpdatePaymentMethodCommandTests
}
[Fact]
public async Task Run_BankAccount_StripeToPayPal_MakesCorrectInvocations_ReturnsMaskedBankAccount()
public async Task Run_BankAccount_NoCurrentCustomer_MakesCorrectInvocations_ReturnsMaskedBankAccount()
{
var organization = new Organization
{
Id = Guid.NewGuid()
};
var customer = new Customer
{
Address = new Address
{
Country = "US",
PostalCode = "12345"
},
Metadata = new Dictionary<string, string>()
};
_subscriberService.GetCustomer(organization).Returns(customer);
const string token = "TOKEN";
var setupIntent = new SetupIntent
{
Id = "seti_123",
PaymentMethod =
new PaymentMethod
{
Type = "us_bank_account",
UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" }
},
NextAction = new SetupIntentNextAction
{
VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits()
},
Status = "requires_action"
};
_stripeAdapter.SetupIntentList(Arg.Is<SetupIntentListOptions>(options =>
options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]);
var result = await _command.Run(organization,
new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = token }, new BillingAddress
{
Country = "US",
PostalCode = "12345"
});
Assert.True(result.IsT0);
var maskedPaymentMethod = result.AsT0;
Assert.True(maskedPaymentMethod.IsT0);
var maskedBankAccount = maskedPaymentMethod.AsT0;
Assert.Equal("Chase", maskedBankAccount.BankName);
Assert.Equal("9999", maskedBankAccount.Last4);
Assert.False(maskedBankAccount.Verified);
await _subscriberService.Received(1).CreateStripeCustomer(organization);
await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id);
}
[Fact]
public async Task Run_BankAccount_StripeToPayPal_MakesCorrectInvocations_ReturnsMaskedBankAccount()
{
var organization = new Organization
{
Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
};
var customer = new Customer
{
Address = new Address
@ -170,7 +233,8 @@ public class UpdatePaymentMethodCommandTests
{
var organization = new Organization
{
Id = Guid.NewGuid()
Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
};
var customer = new Customer
@ -227,7 +291,8 @@ public class UpdatePaymentMethodCommandTests
{
var organization = new Organization
{
Id = Guid.NewGuid()
Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
};
var customer = new Customer
@ -282,7 +347,8 @@ public class UpdatePaymentMethodCommandTests
{
var organization = new Organization
{
Id = Guid.NewGuid()
Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
};
var customer = new Customer
@ -343,7 +409,8 @@ public class UpdatePaymentMethodCommandTests
{
var organization = new Organization
{
Id = Guid.NewGuid()
Id = Guid.NewGuid(),
GatewayCustomerId = "cus_123"
};
var customer = new Customer

View File

@ -25,6 +25,27 @@ public class MaskedPaymentMethodTests
Assert.Equivalent(input.AsT0, output.AsT0);
}
[Fact]
public void Write_Read_BankAccount_WithOptions_Succeeds()
{
MaskedPaymentMethod input = new MaskedBankAccount
{
BankName = "Chase",
Last4 = "9999",
Verified = true
};
var jsonSerializerOptions = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase };
var json = JsonSerializer.Serialize(input, jsonSerializerOptions);
var output = JsonSerializer.Deserialize<MaskedPaymentMethod>(json, jsonSerializerOptions);
Assert.NotNull(output);
Assert.True(output.IsT0);
Assert.Equivalent(input.AsT0, output.AsT0);
}
[Fact]
public void Write_Read_Card_Succeeds()
{

View File

@ -8,6 +8,7 @@ using Bit.Core.Test.Billing.Extensions;
using Braintree;
using Microsoft.Extensions.Logging;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Stripe;
using Xunit;
using Customer = Stripe.Customer;
@ -35,6 +36,23 @@ public class GetPaymentMethodQueryTests
_subscriberService);
}
[Fact]
public async Task Run_NoCustomer_ReturnsNull()
{
var organization = new Organization
{
Id = Guid.NewGuid()
};
_subscriberService.GetCustomer(organization,
Arg.Is<CustomerGetOptions>(options =>
options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).ReturnsNull();
var maskedPaymentMethod = await _query.Run(organization);
Assert.Null(maskedPaymentMethod);
}
[Fact]
public async Task Run_NoPaymentMethod_ReturnsNull()
{