1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-04 20:50:21 -05:00

[PM-18794] Allow provider payment method (#5500)

* Add PaymentSource to ProviderSubscriptionResponse

* Add UpdatePaymentMethod to ProviderBillingController

* Add GetTaxInformation to ProviderBillingController

* Add VerifyBankAccount to ProviderBillingController

* Add feature flag
This commit is contained in:
Alex Morask 2025-03-14 11:33:24 -04:00 committed by GitHub
parent 2df4076a6b
commit 7daf6cfad4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 113 additions and 4 deletions

View File

@ -628,6 +628,19 @@ public class ProviderBillingService(
} }
} }
public async Task UpdatePaymentMethod(
Provider provider,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation)
{
await Task.WhenAll(
subscriberService.UpdatePaymentSource(provider, tokenizedPaymentSource),
subscriberService.UpdateTaxInformation(provider, taxInformation));
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically });
}
public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command) public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command)
{ {
if (command.Configuration.Any(x => x.SeatsMinimum < 0)) if (command.Configuration.Any(x => x.SeatsMinimum < 0))

View File

@ -1,5 +1,6 @@
using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses; using Bit.Api.Billing.Models.Responses;
using Bit.Core;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
@ -20,6 +21,7 @@ namespace Bit.Api.Billing.Controllers;
[Authorize("Application")] [Authorize("Application")]
public class ProviderBillingController( public class ProviderBillingController(
ICurrentContext currentContext, ICurrentContext currentContext,
IFeatureService featureService,
ILogger<BaseProviderController> logger, ILogger<BaseProviderController> logger,
IPricingClient pricingClient, IPricingClient pricingClient,
IProviderBillingService providerBillingService, IProviderBillingService providerBillingService,
@ -71,6 +73,65 @@ public class ProviderBillingController(
"text/csv"); "text/csv");
} }
[HttpPut("payment-method")]
public async Task<IResult> UpdatePaymentMethodAsync(
[FromRoute] Guid providerId,
[FromBody] UpdatePaymentMethodRequestBody requestBody)
{
var allowProviderPaymentMethod = featureService.IsEnabled(FeatureFlagKeys.PM18794_ProviderPaymentMethod);
if (!allowProviderPaymentMethod)
{
return TypedResults.NotFound();
}
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var tokenizedPaymentSource = requestBody.PaymentSource.ToDomain();
var taxInformation = requestBody.TaxInformation.ToDomain();
await providerBillingService.UpdatePaymentMethod(
provider,
tokenizedPaymentSource,
taxInformation);
return TypedResults.Ok();
}
[HttpPost("payment-method/verify-bank-account")]
public async Task<IResult> VerifyBankAccountAsync(
[FromRoute] Guid providerId,
[FromBody] VerifyBankAccountRequestBody requestBody)
{
var allowProviderPaymentMethod = featureService.IsEnabled(FeatureFlagKeys.PM18794_ProviderPaymentMethod);
if (!allowProviderPaymentMethod)
{
return TypedResults.NotFound();
}
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
if (requestBody.DescriptorCode.Length != 6 || !requestBody.DescriptorCode.StartsWith("SM"))
{
return Error.BadRequest("Statement descriptor should be a 6-character value that starts with 'SM'");
}
await subscriberService.VerifyBankAccount(provider, requestBody.DescriptorCode);
return TypedResults.Ok();
}
[HttpGet("subscription")] [HttpGet("subscription")]
public async Task<IResult> GetSubscriptionAsync([FromRoute] Guid providerId) public async Task<IResult> GetSubscriptionAsync([FromRoute] Guid providerId)
{ {
@ -102,12 +163,32 @@ public class ProviderBillingController(
var subscriptionSuspension = await GetSubscriptionSuspensionAsync(stripeAdapter, subscription); var subscriptionSuspension = await GetSubscriptionSuspensionAsync(stripeAdapter, subscription);
var paymentSource = await subscriberService.GetPaymentSource(provider);
var response = ProviderSubscriptionResponse.From( var response = ProviderSubscriptionResponse.From(
subscription, subscription,
configuredProviderPlans, configuredProviderPlans,
taxInformation, taxInformation,
subscriptionSuspension, subscriptionSuspension,
provider); provider,
paymentSource);
return TypedResults.Ok(response);
}
[HttpGet("tax-information")]
public async Task<IResult> GetTaxInformationAsync([FromRoute] Guid providerId)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var taxInformation = await subscriberService.GetTaxInformation(provider);
var response = TaxInformationResponse.From(taxInformation);
return TypedResults.Ok(response); return TypedResults.Ok(response);
} }

View File

@ -16,7 +16,8 @@ public record ProviderSubscriptionResponse(
TaxInformation TaxInformation, TaxInformation TaxInformation,
DateTime? CancelAt, DateTime? CancelAt,
SubscriptionSuspension Suspension, SubscriptionSuspension Suspension,
ProviderType ProviderType) ProviderType ProviderType,
PaymentSource PaymentSource)
{ {
private const string _annualCadence = "Annual"; private const string _annualCadence = "Annual";
private const string _monthlyCadence = "Monthly"; private const string _monthlyCadence = "Monthly";
@ -26,7 +27,8 @@ public record ProviderSubscriptionResponse(
ICollection<ConfiguredProviderPlan> providerPlans, ICollection<ConfiguredProviderPlan> providerPlans,
TaxInformation taxInformation, TaxInformation taxInformation,
SubscriptionSuspension subscriptionSuspension, SubscriptionSuspension subscriptionSuspension,
Provider provider) Provider provider,
PaymentSource paymentSource)
{ {
var providerPlanResponses = providerPlans var providerPlanResponses = providerPlans
.Select(providerPlan => .Select(providerPlan =>
@ -57,7 +59,8 @@ public record ProviderSubscriptionResponse(
taxInformation, taxInformation,
subscription.CancelAt, subscription.CancelAt,
subscriptionSuspension, subscriptionSuspension,
provider.Type); provider.Type,
paymentSource);
} }
} }

View File

@ -95,5 +95,16 @@ public interface IProviderBillingService
Task<Subscription> SetupSubscription( Task<Subscription> SetupSubscription(
Provider provider); Provider provider);
/// <summary>
/// Updates the <paramref name="provider"/>'s payment source and tax information and then sets their subscription's collection_method to be "charge_automatically".
/// </summary>
/// <param name="provider">The <paramref name="provider"/> to update the payment source and tax information for.</param>
/// <param name="tokenizedPaymentSource">The tokenized payment source (ex. Credit Card) to attach to the <paramref name="provider"/>.</param>
/// <param name="taxInformation">The <paramref name="provider"/>'s updated tax information.</param>
Task UpdatePaymentMethod(
Provider provider,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation);
Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command); Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command);
} }

View File

@ -175,6 +175,7 @@ public static class FeatureFlagKeys
public const string WebPush = "web-push"; public const string WebPush = "web-push";
public const string AndroidImportLoginsFlow = "import-logins-flow"; public const string AndroidImportLoginsFlow = "import-logins-flow";
public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features";
public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method";
public static List<string> GetAllKeys() public static List<string> GetAllKeys()
{ {