using System.Globalization;
using Bit.Commercial.Core.Billing.Models;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using CsvHelper;
using Microsoft.Extensions.Logging;
using Stripe;

namespace Bit.Commercial.Core.Billing;

/// <summary>
/// Stripe-backed billing operations for providers (MSP and multi-organization enterprise providers) and the
/// client organizations they manage: adding existing organizations, plan changes, seat scaling, Stripe
/// customer/subscription setup, payment method updates and client invoice reporting.
/// </summary>
public class ProviderBillingService(
    IEventService eventService,
    IGlobalSettings globalSettings,
    ILogger<ProviderBillingService> logger,
    IOrganizationRepository organizationRepository,
    IPaymentService paymentService,
    IPricingClient pricingClient,
    IProviderInvoiceItemRepository providerInvoiceItemRepository,
    IProviderOrganizationRepository providerOrganizationRepository,
    IProviderPlanRepository providerPlanRepository,
    IProviderUserRepository providerUserRepository,
    IStripeAdapter stripeAdapter,
    ISubscriberService subscriberService,
    ITaxService taxService) : IProviderBillingService
{
    /// <summary>
    /// Moves an existing, independently subscribed organization under <paramref name="provider"/>'s management.
    /// The organization's own Stripe subscription is cancelled immediately (invoiced now and prorated), the
    /// organization is converted to the provider's managed plan, the provider's seats are scaled to absorb the
    /// organization's seats, and any balance on the organization's Stripe customer is transferred to the
    /// provider's Stripe customer.
    /// </summary>
    /// <param name="provider">The provider taking over the organization.</param>
    /// <param name="organization">The organization being added. Its <c>Seats</c> must be set.</param>
    /// <param name="key">The key stored on the newly created <see cref="ProviderOrganization"/> row.</param>
    [RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)]
    public async Task AddExistingOrganization(
        Provider provider,
        Organization organization,
        string key)
    {
        // Clear any scheduled end-of-period cancellation before cancelling immediately below.
        // NOTE(review): presumably required so the immediate cancellation is the one that takes effect — confirm.
        await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
            new SubscriptionUpdateOptions { CancelAtPeriodEnd = false });

        var subscription = await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId,
            new SubscriptionCancelOptions
            {
                CancellationDetails = new SubscriptionCancellationDetailsOptions
                {
                    Comment = $"Organization was added to Provider with ID {provider.Id}"
                },
                InvoiceNow = true,
                Prorate = true,
                Expand = ["latest_invoice", "test_clock"]
            });

        // Use the Stripe test clock's frozen time when present so trial detection works in simulated time.
        var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow;

        var wasTrialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now;

        // If the cancellation left a draft invoice for a paying (non-trialing) subscription, finalize it
        // and let Stripe auto-advance collection.
        if (!wasTrialing && subscription.LatestInvoice?.Status == StripeConstants.InvoiceStatus.Draft)
        {
            await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId,
                new InvoiceFinalizeOptions { AutoAdvance = true });
        }

        var managedPlanType = await GetManagedPlanTypeAsync(provider, organization);

        var plan = await pricingClient.GetPlanOrThrow(managedPlanType);

        organization.Plan = plan.Name;
        organization.PlanType = plan.Type;
        organization.MaxCollections = plan.PasswordManager.MaxCollections;
        organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
        organization.UsePolicies = plan.HasPolicies;
        organization.UseSso = plan.HasSso;
        organization.UseGroups = plan.HasGroups;
        organization.UseEvents = plan.HasEvents;
        organization.UseDirectory = plan.HasDirectory;
        organization.UseTotp = plan.HasTotp;
        organization.Use2fa = plan.Has2fa;
        organization.UseApi = plan.HasApi;
        organization.UseResetPassword = plan.HasResetPassword;
        organization.SelfHost = plan.HasSelfHost;
        organization.UsersGetPremium = plan.UsersGetPremium;
        organization.UseCustomPermissions = plan.HasCustomPermissions;
        organization.UseScim = plan.HasScim;
        organization.UseKeyConnector = plan.HasKeyConnector;
        organization.BillingEmail = provider.BillingEmail!;
        // The organization no longer has its own subscription; it is billed through the provider.
        organization.GatewaySubscriptionId = null;
        organization.ExpirationDate = null;
        organization.MaxAutoscaleSeats = null;
        organization.Status = OrganizationStatusType.Managed;

        var providerOrganization = new ProviderOrganization
        {
            ProviderId = provider.Id,
            OrganizationId = organization.Id,
            Key = key
        };

        // The provider's seats must be scaled before the ProviderOrganization row is inserted;
        // otherwise the added organization's seats would be double counted.
        await ScaleSeats(provider, organization.PlanType, organization.Seats!.Value);

        await Task.WhenAll(
            organizationRepository.ReplaceAsync(organization),
            providerOrganizationRepository.CreateAsync(providerOrganization));

        // Unlike GetCustomerOrThrow (used elsewhere in this class), GetCustomer presumably returns null when the
        // customer cannot be retrieved. Guard instead of throwing after the organization has already been moved.
        var clientCustomer = await subscriberService.GetCustomer(organization);

        if (clientCustomer == null)
        {
            logger.LogWarning(
                "Could not retrieve Stripe customer for client organization ({OrganizationID}); skipping balance transfer to provider ({ProviderID})",
                organization.Id, provider.Id);
        }
        else if (clientCustomer.Balance != 0)
        {
            // Carry the client's remaining balance (e.g. unused, prorated time) over to the provider's customer.
            await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId,
                new CustomerBalanceTransactionCreateOptions
                {
                    Amount = clientCustomer.Balance,
                    Currency = "USD",
                    Description = $"Unused, prorated time for client organization with ID {organization.Id}."
                });
        }

        await eventService.LogProviderOrganizationEventAsync(
            providerOrganization,
            EventType.ProviderOrganization_Added);
    }

    /// <summary>
    /// Switches a provider plan to <c>command.NewPlan</c>. It swaps the provider-portal seat price on the
    /// provider's Stripe subscription, keeping the seat quantity, and updates every client organization
    /// of the provider to the new plan.
    /// </summary>
    /// <param name="command">Identifies the provider plan, the provider's Stripe subscription and the new plan type.</param>
    /// <exception cref="BadRequestException">The provider plan does not exist.</exception>
    /// <exception cref="ConflictException">
    /// The subscription, its seat item for the current plan, or a client organization cannot be found.
    /// </exception>
    public async Task ChangePlan(ChangeProviderPlanCommand command)
    {
        var plan = await providerPlanRepository.GetByIdAsync(command.ProviderPlanId);

        if (plan == null)
        {
            throw new BadRequestException("Provider plan not found.");
        }

        if (plan.PlanType == command.NewPlan)
        {
            return;
        }

        var oldPlanConfiguration = await pricingClient.GetPlanOrThrow(plan.PlanType);
        var newPlanConfiguration = await pricingClient.GetPlanOrThrow(command.NewPlan);

        // Resolve and validate the Stripe side before persisting anything, so that a missing subscription or
        // seat item does not leave the provider plan changed in the database while Stripe is untouched.
        Subscription subscription;
        try
        {
            subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, plan.ProviderId);
        }
        catch (InvalidOperationException)
        {
            throw new ConflictException("Subscription not found.");
        }

        var oldSubscriptionItem = subscription.Items.SingleOrDefault(x =>
            x.Price.Id == oldPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId);

        if (oldSubscriptionItem == null)
        {
            throw new ConflictException("Subscription item for the current provider plan not found.");
        }

        plan.PlanType = command.NewPlan;
        await providerPlanRepository.ReplaceAsync(plan);

        // Add the new plan's seat price with the same quantity and delete the old seat item in one update.
        var updateOptions = new SubscriptionUpdateOptions
        {
            Items =
            [
                new SubscriptionItemOptions
                {
                    Price = newPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId,
                    Quantity = oldSubscriptionItem.Quantity
                },
                new SubscriptionItemOptions
                {
                    Id = oldSubscriptionItem.Id,
                    Deleted = true
                }
            ]
        };

        await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId, updateOptions);

        // TODO: Refactor later into a "ChangeClientPlanCommand" (ProviderPlanId, ProviderId, OrganizationId):
        // 1. Retrieve PlanType and PlanName for ProviderPlan
        // 2. Assign PlanType & PlanName to Organization
        // NOTE(review): every client organization of the provider is moved to the new plan, regardless of its
        // current plan type — confirm this is intended for providers with more than one provider plan.
        var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(plan.ProviderId);

        foreach (var providerOrganization in providerOrganizations)
        {
            var organization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId);

            if (organization == null)
            {
                throw new ConflictException($"Organization '{providerOrganization.OrganizationId}' not found.");
            }

            organization.PlanType = command.NewPlan;
            organization.Plan = newPlanConfiguration.Name;
            await organizationRepository.ReplaceAsync(organization);
        }
    }

    /// <summary>
    /// Creates a Stripe customer for a provider's client organization. The provider's address, billing email
    /// and first tax ID are copied onto the new customer. The new customer ID is stored on the organization.
    /// Does nothing (besides logging a warning) if the organization already has a customer.
    /// </summary>
    /// <param name="provider">The managing provider whose Stripe customer supplies address and tax data.</param>
    /// <param name="organization">The client organization to create a customer for.</param>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    public async Task CreateCustomerForClientOrganization(
        Provider provider,
        Organization organization)
    {
        ArgumentNullException.ThrowIfNull(provider);
        ArgumentNullException.ThrowIfNull(organization);

        if (!string.IsNullOrEmpty(organization.GatewayCustomerId))
        {
            logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id,
                nameof(organization.GatewayCustomerId));

            return;
        }

        var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions
        {
            Expand = ["tax_ids"]
        });

        var providerTaxId = providerCustomer.TaxIds.FirstOrDefault();

        var organizationDisplayName = organization.DisplayName();

        var customerCreateOptions = new CustomerCreateOptions
        {
            Address = new AddressOptions
            {
                Country = providerCustomer.Address?.Country,
                PostalCode = providerCustomer.Address?.PostalCode,
                Line1 = providerCustomer.Address?.Line1,
                Line2 = providerCustomer.Address?.Line2,
                City = providerCustomer.Address?.City,
                State = providerCustomer.Address?.State
            },
            Name = organizationDisplayName,
            Description = $"{provider.Name} Client Organization",
            Email = provider.BillingEmail,
            InvoiceSettings = new CustomerInvoiceSettingsOptions
            {
                CustomFields =
                [
                    new CustomerInvoiceSettingsCustomFieldOptions
                    {
                        Name = organization.SubscriberType(),
                        // Stripe limits invoice custom field values to 30 characters.
                        Value = organizationDisplayName.Length <= 30
                            ? organizationDisplayName
                            : organizationDisplayName[..30]
                    }
                ]
            },
            Metadata = new Dictionary<string, string>
            {
                { "region", globalSettings.BaseServiceUri.CloudRegion }
            },
            TaxIdData = providerTaxId == null
                ? null
                :
                [
                    new CustomerTaxIdDataOptions { Type = providerTaxId.Type, Value = providerTaxId.Value }
                ]
        };

        var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);

        organization.GatewayCustomerId = customer.Id;

        await organizationRepository.ReplaceAsync(organization);
    }

    /// <summary>
    /// Builds a CSV report of the provider invoice items recorded for a Stripe invoice.
    /// </summary>
    /// <param name="invoiceId">The Stripe invoice ID; must be non-empty.</param>
    /// <returns>The CSV file as bytes, or <c>null</c> when no invoice items are recorded for the invoice.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="invoiceId"/> is null or empty.
    /// </exception>
    public async Task<byte[]> GenerateClientInvoiceReport(
        string invoiceId)
    {
        ArgumentException.ThrowIfNullOrEmpty(invoiceId);

        var invoiceItems = await providerInvoiceItemRepository.GetByInvoiceId(invoiceId);

        if (invoiceItems.Count == 0)
        {
            logger.LogError("No provider invoice item records were found for invoice ({InvoiceID})", invoiceId);

            return null;
        }

        var csvRows = invoiceItems.Select(ProviderClientInvoiceReportRow.From);

        using var memoryStream = new MemoryStream();

        await using var streamWriter = new StreamWriter(memoryStream);

        // CsvHelper derives the field delimiter (the culture's list separator) and number formatting from the
        // culture. InvariantCulture keeps the report independent of whatever locale the server runs under.
        await using var csvWriter = new CsvWriter(streamWriter, CultureInfo.InvariantCulture);

        await csvWriter.WriteRecordsAsync(csvRows);

        await streamWriter.FlushAsync();

        // ToArray copies the whole buffer regardless of the stream position, so no rewind is needed.
        return memoryStream.ToArray();
    }

    /// <summary>
    /// Lists the organizations the given provider user could add to the provider. A candidate is included
    /// only if its Stripe subscription is active, trialing or past due. Provider service users additionally
    /// see an organization as <c>Disabled</c> when adding it would require purchasing seats.
    /// </summary>
    /// <param name="provider">The provider organizations would be added to.</param>
    /// <param name="userId">The requesting user; must be a confirmed member of the provider.</param>
    /// <exception cref="UnauthorizedAccessException">The user is not a confirmed member of the provider.</exception>
    [RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)]
    public async Task<IEnumerable<AddableOrganization>> GetAddableOrganizations(
        Provider provider,
        Guid userId)
    {
        var providerUser = await providerUserRepository.GetByProviderUserAsync(provider.Id, userId);

        if (providerUser is not { Status: ProviderUserStatusType.Confirmed })
        {
            throw new UnauthorizedAccessException();
        }

        var candidates = await organizationRepository.GetAddableToProviderByUserIdAsync(userId, provider.Type);

        var active = (await Task.WhenAll(candidates.Select(async organization =>
            {
                var subscription = await subscriberService.GetSubscription(organization);
                return (organization, subscription);
            })))
            .Where(pair => pair.subscription is
            {
                Status:
                    StripeConstants.SubscriptionStatus.Active or
                    StripeConstants.SubscriptionStatus.Trialing or
                    StripeConstants.SubscriptionStatus.PastDue
            })
            .ToList();

        if (active.Count == 0)
        {
            return [];
        }

        return await Task.WhenAll(active.Select(async pair =>
        {
            var (organization, _) = pair;

            var planName = await DerivePlanName(provider, organization);

            var addable = new AddableOrganization(
                organization.Id,
                organization.Name,
                planName,
                organization.Seats!.Value);

            if (providerUser.Type != ProviderUserType.ServiceUser)
            {
                return addable;
            }

            var applicablePlanType = await GetManagedPlanTypeAsync(provider, organization);

            var requiresPurchase =
                await SeatAdjustmentResultsInPurchase(provider, applicablePlanType, organization.Seats!.Value);

            return addable with { Disabled = requiresPurchase };
        }));

        // MSP providers show the plan family ("Enterprise" / "Teams"); other providers show the pricing plan name.
        async Task<string> DerivePlanName(Provider localProvider, Organization localOrganization)
        {
            if (localProvider.Type == ProviderType.Msp)
            {
                return localOrganization.PlanType switch
                {
                    var planType when PlanConstants.EnterprisePlanTypes.Contains(planType) => "Enterprise",
                    var planType when PlanConstants.TeamsPlanTypes.Contains(planType) => "Teams",
                    _ => throw new BillingException()
                };
            }

            var plan = await pricingClient.GetPlanOrThrow(localOrganization.PlanType);

            return plan.Name;
        }
    }

    /// <summary>
    /// Adjusts the provider's assigned seats for <paramref name="planType"/> by <paramref name="seatAdjustment"/>.
    /// The Stripe subscription is only scaled for the part of the seat total that lies above the plan's seat
    /// minimum.
    /// </summary>
    /// <param name="provider">The provider whose seats are scaled.</param>
    /// <param name="planType">The provider plan type whose seats change.</param>
    /// <param name="seatAdjustment">The seat delta; may be negative.</param>
    /// <exception cref="BillingException">The provider plan is missing or misconfigured.</exception>
    public async Task ScaleSeats(
        Provider provider,
        PlanType planType,
        int seatAdjustment)
    {
        var providerPlan = await GetProviderPlanAsync(provider, planType);

        var seatMinimum = providerPlan.SeatMinimum ?? 0;

        var currentlyAssignedSeatTotal = await GetAssignedSeatTotalAsync(provider, planType);

        var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment;

        var update = CurrySeatScalingUpdate(
            provider,
            providerPlan,
            newlyAssignedSeatTotal);

        if (currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal <= seatMinimum)
        {
            // Below the limit => below the limit:
            // no subscription update required; only the provider's allocated seats change.
            providerPlan.AllocatedSeats = newlyAssignedSeatTotal;

            await providerPlanRepository.ReplaceAsync(providerPlan);
        }
        else if (currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum)
        {
            // Below the limit => above the limit:
            // scale the subscription up from the seat minimum to the newly assigned seat total.
            await update(
                seatMinimum,
                newlyAssignedSeatTotal);
        }
        else if (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum)
        {
            // Above the limit => above the limit:
            // scale the subscription from the currently assigned seat total to the newly assigned seat total.
            await update(
                currentlyAssignedSeatTotal,
                newlyAssignedSeatTotal);
        }
        else if (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal <= seatMinimum)
        {
            // Above the limit => below the limit:
            // scale the subscription down from the currently assigned seat total to the seat minimum.
            await update(
                currentlyAssignedSeatTotal,
                seatMinimum);
        }
    }

    /// <summary>
    /// Determines whether applying <paramref name="seatAdjustment"/> to the provider's assigned seats for
    /// <paramref name="planType"/> would require purchasing additional seats. This is the case when the
    /// assigned total crosses above the seat minimum or grows further while already above it.
    /// </summary>
    /// <param name="provider">The provider whose seats would change.</param>
    /// <param name="planType">The provider plan type whose seats would change.</param>
    /// <param name="seatAdjustment">The seat delta; may be negative.</param>
    /// <returns><c>true</c> if the adjustment results in a seat purchase.</returns>
    /// <exception cref="BillingException">The provider plan is missing or misconfigured.</exception>
    public async Task<bool> SeatAdjustmentResultsInPurchase(
        Provider provider,
        PlanType planType,
        int seatAdjustment)
    {
        var providerPlan = await GetProviderPlanAsync(provider, planType);

        // Treat a missing minimum as 0, exactly as ScaleSeats does, so both methods agree on what a purchase is.
        var seatMinimum = providerPlan.SeatMinimum ?? 0;

        var currentlyAssignedSeatTotal = await GetAssignedSeatTotalAsync(provider, planType);

        var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment;

        return
            // Below the limit to above the limit
            (currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum) ||
            // Above the limit to further above the limit
            (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum &&
             newlyAssignedSeatTotal > currentlyAssignedSeatTotal);
    }

    /// <summary>
    /// Creates the provider's Stripe customer from the supplied billing address and optional tax ID.
    /// The provider's discount is applied as a coupon when one is set.
    /// </summary>
    /// <param name="provider">The provider to create a customer for.</param>
    /// <param name="taxInfo">Billing address and tax ID; country and postal code are required.</param>
    /// <returns>The created Stripe customer.</returns>
    /// <exception cref="BillingException">The country or postal code is missing.</exception>
    /// <exception cref="BadRequestException">
    /// The tax ID type cannot be inferred, or Stripe rejects the tax ID as invalid.
    /// </exception>
    public async Task<Customer> SetupCustomer(
        Provider provider,
        TaxInfo taxInfo)
    {
        if (taxInfo is not
            {
                BillingAddressCountry: not null and not "",
                BillingAddressPostalCode: not null and not ""
            })
        {
            logger.LogError("Cannot create customer for provider ({ProviderID}) without both a country and postal code", provider.Id);

            throw new BillingException();
        }

        var providerDisplayName = provider.DisplayName();

        var options = new CustomerCreateOptions
        {
            Address = new AddressOptions
            {
                Country = taxInfo.BillingAddressCountry,
                PostalCode = taxInfo.BillingAddressPostalCode,
                Line1 = taxInfo.BillingAddressLine1,
                Line2 = taxInfo.BillingAddressLine2,
                City = taxInfo.BillingAddressCity,
                State = taxInfo.BillingAddressState
            },
            Description = provider.DisplayBusinessName(),
            Email = provider.BillingEmail,
            InvoiceSettings = new CustomerInvoiceSettingsOptions
            {
                CustomFields =
                [
                    new CustomerInvoiceSettingsCustomFieldOptions
                    {
                        Name = provider.SubscriberType(),
                        // Stripe limits invoice custom field values to 30 characters.
                        Value = providerDisplayName?.Length <= 30
                            ? providerDisplayName
                            : providerDisplayName?[..30]
                    }
                ]
            },
            Metadata = new Dictionary<string, string>
            {
                { "region", globalSettings.BaseServiceUri.CloudRegion }
            }
        };

        if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber))
        {
            var taxIdType = taxService.GetStripeTaxCode(
                taxInfo.BillingAddressCountry,
                taxInfo.TaxIdNumber);

            if (taxIdType == null)
            {
                logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
                    taxInfo.BillingAddressCountry,
                    taxInfo.TaxIdNumber);

                throw new BadRequestException("billingTaxIdTypeInferenceError");
            }

            options.TaxIdData =
            [
                new CustomerTaxIdDataOptions { Type = taxIdType, Value = taxInfo.TaxIdNumber }
            ];
        }

        if (!string.IsNullOrEmpty(provider.DiscountId))
        {
            options.Coupon = provider.DiscountId;
        }

        try
        {
            return await stripeAdapter.CustomerCreateAsync(options);
        }
        catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.TaxIdInvalid)
        {
            throw new BadRequestException("Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid.");
        }
    }

    /// <summary>
    /// Creates the provider's Stripe subscription. It contains one seat item per provider plan, quantity
    /// equal to the plan's seat minimum, billed by invoice with 30 days until due and automatic tax enabled.
    /// </summary>
    /// <param name="provider">The provider to subscribe; must already have a Stripe customer.</param>
    /// <returns>The newly created, active subscription.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
    /// <exception cref="BillingException">
    /// The provider has no plans, a plan is not configured, or the created subscription is not active.
    /// </exception>
    /// <exception cref="BadRequestException">Stripe rejects the customer's tax location.</exception>
    public async Task<Subscription> SetupSubscription(
        Provider provider)
    {
        ArgumentNullException.ThrowIfNull(provider);

        var customer = await subscriberService.GetCustomerOrThrow(provider);

        var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);

        if (providerPlans == null || providerPlans.Count == 0)
        {
            logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", provider.Id);

            throw new BillingException();
        }

        var subscriptionItemOptionsList = new List<SubscriptionItemOptions>();

        foreach (var providerPlan in providerPlans)
        {
            var plan = await pricingClient.GetPlanOrThrow(providerPlan.PlanType);

            if (!providerPlan.IsConfigured())
            {
                logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured {ProviderName} plan", provider.Id, plan.Name);
                throw new BillingException();
            }

            subscriptionItemOptionsList.Add(new SubscriptionItemOptions
            {
                Price = plan.PasswordManager.StripeProviderPortalSeatPlanId,
                Quantity = providerPlan.SeatMinimum
            });
        }

        var subscriptionCreateOptions = new SubscriptionCreateOptions
        {
            AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
            CollectionMethod = StripeConstants.CollectionMethod.SendInvoice,
            Customer = customer.Id,
            DaysUntilDue = 30,
            Items = subscriptionItemOptionsList,
            Metadata = new Dictionary<string, string>
            {
                { "providerId", provider.Id.ToString() }
            },
            OffSession = true,
            ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations
        };

        try
        {
            var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);

            if (subscription.Status == StripeConstants.SubscriptionStatus.Active)
            {
                return subscription;
            }

            logger.LogError(
                "Newly created provider ({ProviderID}) subscription ({SubscriptionID}) has inactive status: {Status}",
                provider.Id,
                subscription.Id,
                subscription.Status);

            throw new BillingException();
        }
        catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid)
        {
            throw new BadRequestException("Your location wasn't recognized. Please ensure your country and postal code are valid.");
        }
    }

    /// <summary>
    /// Replaces the provider's payment source and tax information. It then switches the provider's subscription
    /// to automatic charging.
    /// </summary>
    /// <param name="provider">The provider being updated.</param>
    /// <param name="tokenizedPaymentSource">The new payment source.</param>
    /// <param name="taxInformation">The new tax information.</param>
    public async Task UpdatePaymentMethod(
        Provider provider,
        TokenizedPaymentSource tokenizedPaymentSource,
        TaxInformation taxInformation)
    {
        await Task.WhenAll(
            subscriberService.UpdatePaymentSource(provider, tokenizedPaymentSource),
            subscriberService.UpdateTaxInformation(provider, taxInformation));

        await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
            new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically });
    }

    /// <summary>
    /// Applies new per-plan seat minimums for a provider. Each changed plan's purchased seats are recomputed,
    /// and all resulting Stripe seat-quantity changes are pushed in a single subscription update.
    /// </summary>
    /// <param name="command">The provider, its Stripe subscription and the new seat minimum per plan.</param>
    /// <exception cref="BadRequestException">Any seat minimum is negative.</exception>
    /// <exception cref="ConflictException">The provider's subscription cannot be found.</exception>
    public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command)
    {
        if (command.Configuration.Any(x => x.SeatsMinimum < 0))
        {
            throw new BadRequestException("Provider seat minimums must be at least 0.");
        }

        Subscription subscription;
        try
        {
            subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, command.Id);
        }
        catch (InvalidOperationException)
        {
            throw new ConflictException("Subscription not found.");
        }

        var subscriptionItemOptionsList = new List<SubscriptionItemOptions>();

        var providerPlans = await providerPlanRepository.GetByProviderId(command.Id);

        foreach (var newPlanConfiguration in command.Configuration)
        {
            var providerPlan =
                providerPlans.Single(providerPlan => providerPlan.PlanType == newPlanConfiguration.Plan);

            if (providerPlan.SeatMinimum != newPlanConfiguration.SeatsMinimum)
            {
                var newPlan = await pricingClient.GetPlanOrThrow(newPlanConfiguration.Plan);
                var priceId = newPlan.PasswordManager.StripeProviderPortalSeatPlanId;
                var subscriptionItem = subscription.Items.First(item => item.Price.Id == priceId);

                if (providerPlan.PurchasedSeats == 0)
                {
                    if (providerPlan.AllocatedSeats > newPlanConfiguration.SeatsMinimum)
                    {
                        // Allocated seats exceed the new minimum: the excess becomes purchased seats and the
                        // subscription quantity follows the allocated total.
                        providerPlan.PurchasedSeats = providerPlan.AllocatedSeats - newPlanConfiguration.SeatsMinimum;

                        subscriptionItemOptionsList.Add(new SubscriptionItemOptions
                        {
                            Id = subscriptionItem.Id,
                            Price = priceId,
                            Quantity = providerPlan.AllocatedSeats
                        });
                    }
                    else
                    {
                        subscriptionItemOptionsList.Add(new SubscriptionItemOptions
                        {
                            Id = subscriptionItem.Id,
                            Price = priceId,
                            Quantity = newPlanConfiguration.SeatsMinimum
                        });
                    }
                }
                else
                {
                    var totalSeats = providerPlan.SeatMinimum + providerPlan.PurchasedSeats;

                    if (newPlanConfiguration.SeatsMinimum <= totalSeats)
                    {
                        // The subscribed total is unchanged; only the minimum/purchased split moves.
                        providerPlan.PurchasedSeats = totalSeats - newPlanConfiguration.SeatsMinimum;
                    }
                    else
                    {
                        providerPlan.PurchasedSeats = 0;

                        subscriptionItemOptionsList.Add(new SubscriptionItemOptions
                        {
                            Id = subscriptionItem.Id,
                            Price = priceId,
                            Quantity = newPlanConfiguration.SeatsMinimum
                        });
                    }
                }

                providerPlan.SeatMinimum = newPlanConfiguration.SeatsMinimum;

                await providerPlanRepository.ReplaceAsync(providerPlan);
            }
        }

        if (subscriptionItemOptionsList.Count > 0)
        {
            await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId,
                new SubscriptionUpdateOptions { Items = subscriptionItemOptionsList });
        }
    }

    /// <summary>
    /// Returns a function that scales the provider's Stripe subscription from one subscribed seat count to
    /// another. It then records the resulting purchased seats (seats above the minimum) and
    /// <paramref name="newlyAssignedSeats"/> on <paramref name="providerPlan"/>.
    /// </summary>
    private Func<int, int, Task> CurrySeatScalingUpdate(
        Provider provider,
        ProviderPlan providerPlan,
        int newlyAssignedSeats) => async (currentlySubscribedSeats, newlySubscribedSeats) =>
    {
        var plan = await pricingClient.GetPlanOrThrow(providerPlan.PlanType);

        await paymentService.AdjustSeats(
            provider,
            plan,
            currentlySubscribedSeats,
            newlySubscribedSeats);

        var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum
            ? newlySubscribedSeats - providerPlan.SeatMinimum
            : 0;

        providerPlan.PurchasedSeats = newlyPurchasedSeats;
        providerPlan.AllocatedSeats = newlyAssignedSeats;

        await providerPlanRepository.ReplaceAsync(providerPlan);
    };

    /// <summary>
    /// Sums the seats of the provider's managed client organizations whose plan name matches
    /// <paramref name="planType"/>'s plan name.
    /// </summary>
    // TODO: Replace with SPROC
    private async Task<int> GetAssignedSeatTotalAsync(Provider provider, PlanType planType)
    {
        var providerOrganizations =
            await providerOrganizationRepository.GetManyDetailsByProviderAsync(provider.Id);

        var plan = await pricingClient.GetPlanOrThrow(planType);

        return providerOrganizations
            .Where(providerOrganization => providerOrganization.Plan == plan.Name && providerOrganization.Status == OrganizationStatusType.Managed)
            .Sum(providerOrganization => providerOrganization.Seats ?? 0);
    }

    /// <summary>
    /// Loads the provider's plan of the given type.
    /// </summary>
    /// <exception cref="BillingException">The plan is missing or not configured.</exception>
    // TODO: Replace with SPROC
    private async Task<ProviderPlan> GetProviderPlanAsync(Provider provider, PlanType planType)
    {
        var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);

        var providerPlan = providerPlans.FirstOrDefault(x => x.PlanType == planType);

        if (providerPlan == null || !providerPlan.IsConfigured())
        {
            throw new BillingException(message: "Provider plan is missing or misconfigured");
        }

        return providerPlan;
    }

    /// <summary>
    /// Determines the plan type an organization receives once managed by <paramref name="provider"/>.
    /// Multi-organization enterprise providers use their (first) provider plan. Other providers map Teams
    /// plans to <see cref="PlanType.TeamsMonthly"/> and Enterprise plans to <see cref="PlanType.EnterpriseMonthly"/>.
    /// </summary>
    /// <exception cref="BillingException">
    /// No provider plan exists for a multi-organization enterprise provider, or the organization's plan is
    /// neither Teams nor Enterprise.
    /// </exception>
    private async Task<PlanType> GetManagedPlanTypeAsync(
        Provider provider,
        Organization organization)
    {
        if (provider.Type == ProviderType.MultiOrganizationEnterprise)
        {
            var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);

            var providerPlan = providerPlans.FirstOrDefault();

            if (providerPlan == null)
            {
                throw new BillingException(message: "Provider plan is missing or misconfigured");
            }

            return providerPlan.PlanType;
        }

        return organization.PlanType switch
        {
            var planType when PlanConstants.TeamsPlanTypes.Contains(planType) => PlanType.TeamsMonthly,
            var planType when PlanConstants.EnterprisePlanTypes.Contains(planType) => PlanType.EnterpriseMonthly,
            _ => throw new BillingException()
        };
    }
}