1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-03 00:52:49 -05:00

Start subscription for provider during setup process. (#3957)

This commit is contained in:
Alex Morask
2024-04-10 14:10:53 -04:00
committed by GitHub
parent 2c36784cda
commit 3cdfbdb22d
14 changed files with 749 additions and 17 deletions

View File

@ -4,4 +4,5 @@ public enum ProviderStatusType : byte
{
Pending = 0,
Created = 1,
Billable = 2
}

View File

@ -0,0 +1,11 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Models.Business;
namespace Bit.Core.Billing.Commands;
public interface IStartSubscriptionCommand
{
Task StartSubscription(
Provider provider,
TaxInfo taxInfo);
}

View File

@ -0,0 +1,209 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Repositories;
using Bit.Core.Enums;
using Bit.Core.Models.Business;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using Stripe;
using static Bit.Core.Billing.Utilities;
namespace Bit.Core.Billing.Commands.Implementations;
public class StartSubscriptionCommand(
IGlobalSettings globalSettings,
ILogger<StartSubscriptionCommand> logger,
IProviderPlanRepository providerPlanRepository,
IProviderRepository providerRepository,
IStripeAdapter stripeAdapter) : IStartSubscriptionCommand
{
public async Task StartSubscription(
Provider provider,
TaxInfo taxInfo)
{
ArgumentNullException.ThrowIfNull(provider);
ArgumentNullException.ThrowIfNull(taxInfo);
if (!string.IsNullOrEmpty(provider.GatewaySubscriptionId))
{
logger.LogWarning("Cannot start Provider subscription - Provider ({ID}) already has a {FieldName}", provider.Id, nameof(provider.GatewaySubscriptionId));
throw ContactSupport();
}
if (string.IsNullOrEmpty(taxInfo.BillingAddressCountry) ||
string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode))
{
logger.LogError("Cannot start Provider subscription - Both the Provider's ({ID}) country and postal code are required", provider.Id);
throw ContactSupport();
}
var customer = await GetOrCreateCustomerAsync(provider, taxInfo);
if (taxInfo.BillingAddressCountry == "US" && customer.Tax is not { AutomaticTax: StripeConstants.AutomaticTaxStatus.Supported })
{
logger.LogError("Cannot start Provider subscription - Provider's ({ProviderID}) Stripe customer ({CustomerID}) is in the US and does not support automatic tax", provider.Id, customer.Id);
throw ContactSupport();
}
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
if (providerPlans == null || providerPlans.Count == 0)
{
logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured plans", provider.Id);
throw ContactSupport();
}
var subscriptionItemOptionsList = new List<SubscriptionItemOptions>();
var teamsProviderPlan =
providerPlans.SingleOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly);
if (teamsProviderPlan == null)
{
logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured Teams Monthly plan", provider.Id);
throw ContactSupport();
}
var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
subscriptionItemOptionsList.Add(new SubscriptionItemOptions
{
Price = teamsPlan.PasswordManager.StripeSeatPlanId,
Quantity = teamsProviderPlan.SeatMinimum
});
var enterpriseProviderPlan =
providerPlans.SingleOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly);
if (enterpriseProviderPlan == null)
{
logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured Enterprise Monthly plan", provider.Id);
throw ContactSupport();
}
var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly);
subscriptionItemOptionsList.Add(new SubscriptionItemOptions
{
Price = enterprisePlan.PasswordManager.StripeSeatPlanId,
Quantity = enterpriseProviderPlan.SeatMinimum
});
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
CollectionMethod = StripeConstants.CollectionMethod.SendInvoice,
Customer = customer.Id,
DaysUntilDue = 30,
Items = subscriptionItemOptionsList,
Metadata = new Dictionary<string, string>
{
{ "providerId", provider.Id.ToString() }
},
OffSession = true,
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations
};
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
provider.GatewaySubscriptionId = subscription.Id;
if (subscription.Status == StripeConstants.SubscriptionStatus.Incomplete)
{
await providerRepository.ReplaceAsync(provider);
logger.LogError("Started incomplete Provider ({ProviderID}) subscription ({SubscriptionID})", provider.Id, subscription.Id);
throw ContactSupport();
}
provider.Status = ProviderStatusType.Billable;
await providerRepository.ReplaceAsync(provider);
}
// ReSharper disable once SuggestBaseTypeForParameter
private async Task<Customer> GetOrCreateCustomerAsync(
Provider provider,
TaxInfo taxInfo)
{
if (!string.IsNullOrEmpty(provider.GatewayCustomerId))
{
var existingCustomer = await stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, new CustomerGetOptions
{
Expand = ["tax"]
});
if (existingCustomer != null)
{
return existingCustomer;
}
logger.LogError("Cannot start Provider subscription - Provider's ({ProviderID}) {CustomerIDFieldName} did not relate to a Stripe customer", provider.Id, nameof(provider.GatewayCustomerId));
throw ContactSupport();
}
var providerDisplayName = provider.DisplayName();
var customerCreateOptions = new CustomerCreateOptions
{
Address = new AddressOptions
{
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode,
Line1 = taxInfo.BillingAddressLine1,
Line2 = taxInfo.BillingAddressLine2,
City = taxInfo.BillingAddressCity,
State = taxInfo.BillingAddressState
},
Coupon = "msp-discount-35",
Description = provider.DisplayBusinessName(),
Email = provider.BillingEmail,
Expand = ["tax"],
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = provider.SubscriberType(),
Value = providerDisplayName.Length <= 30
? providerDisplayName
: providerDisplayName[..30]
}
]
},
Metadata = new Dictionary<string, string>
{
{ "region", globalSettings.BaseServiceUri.CloudRegion }
},
TaxIdData = taxInfo.HasTaxId ?
[
new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber }
]
: null
};
var createdCustomer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
provider.GatewayCustomerId = createdCustomer.Id;
await providerRepository.ReplaceAsync(provider);
return createdCustomer;
}
}

View File

@ -0,0 +1,37 @@
namespace Bit.Core.Billing.Constants;
public static class StripeConstants
{
public static class AutomaticTaxStatus
{
public const string Failed = "failed";
public const string NotCollecting = "not_collecting";
public const string Supported = "supported";
public const string UnrecognizedLocation = "unrecognized_location";
}
public static class CollectionMethod
{
public const string ChargeAutomatically = "charge_automatically";
public const string SendInvoice = "send_invoice";
}
public static class ProrationBehavior
{
public const string AlwaysInvoice = "always_invoice";
public const string CreateProrations = "create_prorations";
public const string None = "none";
}
public static class SubscriptionStatus
{
public const string Trialing = "trialing";
public const string Active = "active";
public const string Incomplete = "incomplete";
public const string IncompleteExpired = "incomplete_expired";
public const string PastDue = "past_due";
public const string Canceled = "canceled";
public const string Unpaid = "unpaid";
public const string Paused = "paused";
}
}

View File

@ -1,9 +0,0 @@
namespace Bit.Core.Billing.Constants;
public static class StripeCustomerAutomaticTaxStatus
{
public const string Failed = "failed";
public const string NotCollecting = "not_collecting";
public const string Supported = "supported";
public const string UnrecognizedLocation = "unrecognized_location";
}

View File

@ -19,5 +19,6 @@ public static class ServiceCollectionExtensions
services.AddTransient<IAssignSeatsToClientOrganizationCommand, AssignSeatsToClientOrganizationCommand>();
services.AddTransient<ICancelSubscriptionCommand, CancelSubscriptionCommand>();
services.AddTransient<IRemovePaymentMethodCommand, RemovePaymentMethodCommand>();
services.AddTransient<IStartSubscriptionCommand, StartSubscriptionCommand>();
}
}

View File

@ -1923,7 +1923,7 @@ public class StripePaymentService : IPaymentService
/// <param name="customer"></param>
/// <returns></returns>
private static bool CustomerHasTaxLocationVerified(Customer customer) =>
customer?.Tax?.AutomaticTax == StripeCustomerAutomaticTaxStatus.Supported;
customer?.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported;
// We are taking only first 30 characters of the SubscriberName because stripe provide
// for 30 characters for custom_fields,see the link: https://stripe.com/docs/api/invoices/create