1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-01 08:02:49 -05:00

[AC-2551] Consolidated Billing Migration (#4616)

* Move existing Billing SQL files into dbo folder

I noticed that every other team had a nested dbo folder under their team folder while Billing did not. This change replicates that.

* Add SQL files for ClientOrganizationMigrationRecord table

* Add SQL Server migration for ClientOrganizationMigrationRecord table

* Add ClientOrganizationMigrationRecord entity and repository interface

* Add ClientOrganizationMigrationRecord Dapper repository

* Add ClientOrganizationMigrationRecord EF repository

* Add EF migrations for ClientOrganizationMigrationRecord table

* Implement migration process

* Wire up new Admin tool to migrate providers

* Run dotnet format

* Updated coupon and credit application per product request

* AC-3057-3058: Fix expiration date and enabled from webhook processing

* Run dotnet format

* AC-3059: Fix assigned seats during migration

* Updated AllocatedSeats in the case plan already exists

* Update migration scripts to reflect current date
This commit is contained in:
Alex Morask
2024-10-04 10:55:00 -04:00
committed by GitHub
parent 738febf031
commit 0496085c39
63 changed files with 10418 additions and 3 deletions

View File

@ -0,0 +1,32 @@
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Utilities;
#nullable enable
namespace Bit.Core.Billing.Entities;
public class ClientOrganizationMigrationRecord : ITableObject<Guid>
{
public Guid Id { get; set; }
public Guid OrganizationId { get; set; }
public Guid ProviderId { get; set; }
public PlanType PlanType { get; set; }
public int Seats { get; set; }
public short? MaxStorageGb { get; set; }
[MaxLength(50)] public string GatewayCustomerId { get; set; } = null!;
[MaxLength(50)] public string GatewaySubscriptionId { get; set; } = null!;
public DateTime? ExpirationDate { get; set; }
public int? MaxAutoscaleSeats { get; set; }
public OrganizationStatusType Status { get; set; }
public void SetNewId()
{
if (Id == default)
{
Id = CoreHelpers.GenerateComb();
}
}
}

View File

@ -0,0 +1,23 @@
namespace Bit.Core.Billing.Migration.Models;
public enum ClientMigrationProgress
{
Started = 1,
MigrationRecordCreated = 2,
SubscriptionEnded = 3,
Completed = 4,
Reversing = 5,
ResetOrganization = 6,
RecreatedSubscription = 7,
RemovedMigrationRecord = 8,
Reversed = 9
}
public class ClientMigrationTracker
{
public Guid ProviderId { get; set; }
public Guid OrganizationId { get; set; }
public string OrganizationName { get; set; }
public ClientMigrationProgress Progress { get; set; } = ClientMigrationProgress.Started;
}

View File

@ -0,0 +1,45 @@
using Bit.Core.Billing.Entities;
namespace Bit.Core.Billing.Migration.Models;
public class ProviderMigrationResult
{
public Guid ProviderId { get; set; }
public string ProviderName { get; set; }
public string Result { get; set; }
public List<ClientMigrationResult> Clients { get; set; }
}
public class ClientMigrationResult
{
public Guid OrganizationId { get; set; }
public string OrganizationName { get; set; }
public string Result { get; set; }
public ClientPreviousState PreviousState { get; set; }
}
public class ClientPreviousState
{
public ClientPreviousState() { }
public ClientPreviousState(ClientOrganizationMigrationRecord migrationRecord)
{
PlanType = migrationRecord.PlanType.ToString();
Seats = migrationRecord.Seats;
MaxStorageGb = migrationRecord.MaxStorageGb;
GatewayCustomerId = migrationRecord.GatewayCustomerId;
GatewaySubscriptionId = migrationRecord.GatewaySubscriptionId;
ExpirationDate = migrationRecord.ExpirationDate;
MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats;
Status = migrationRecord.Status.ToString();
}
public string PlanType { get; set; }
public int Seats { get; set; }
public short? MaxStorageGb { get; set; }
public string GatewayCustomerId { get; set; } = null!;
public string GatewaySubscriptionId { get; set; } = null!;
public DateTime? ExpirationDate { get; set; }
public int? MaxAutoscaleSeats { get; set; }
public string Status { get; set; }
}

View File

@ -0,0 +1,25 @@
namespace Bit.Core.Billing.Migration.Models;
public enum ProviderMigrationProgress
{
Started = 1,
ClientsMigrated = 2,
TeamsPlanConfigured = 3,
EnterprisePlanConfigured = 4,
CustomerSetup = 5,
SubscriptionSetup = 6,
CreditApplied = 7,
Completed = 8,
Reversing = 9,
ReversedClientMigrations = 10,
RemovedProviderPlans = 11
}
public class ProviderMigrationTracker
{
public Guid ProviderId { get; set; }
public string ProviderName { get; set; }
public List<Guid> OrganizationIds { get; set; }
public ProviderMigrationProgress Progress { get; set; } = ProviderMigrationProgress.Started;
}

View File

@ -0,0 +1,15 @@
using Bit.Core.Billing.Migration.Services;
using Bit.Core.Billing.Migration.Services.Implementations;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Billing.Migration;
public static class ServiceCollectionExtensions
{
public static void AddProviderMigration(this IServiceCollection services)
{
services.AddTransient<IMigrationTrackerCache, MigrationTrackerDistributedCache>();
services.AddTransient<IOrganizationMigrator, OrganizationMigrator>();
services.AddTransient<IProviderMigrator, ProviderMigrator>();
}
}

View File

@ -0,0 +1,17 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Migration.Models;
namespace Bit.Core.Billing.Migration.Services;
public interface IMigrationTrackerCache
{
Task StartTracker(Provider provider);
Task SetOrganizationIds(Guid providerId, IEnumerable<Guid> organizationIds);
Task<ProviderMigrationTracker> GetTracker(Guid providerId);
Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status);
Task StartTracker(Guid providerId, Organization organization);
Task<ClientMigrationTracker> GetTracker(Guid providerId, Guid organizationId);
Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status);
}

View File

@ -0,0 +1,8 @@
using Bit.Core.AdminConsole.Entities;
namespace Bit.Core.Billing.Migration.Services;
public interface IOrganizationMigrator
{
Task Migrate(Guid providerId, Organization organization);
}

View File

@ -0,0 +1,10 @@
using Bit.Core.Billing.Migration.Models;
namespace Bit.Core.Billing.Migration.Services;
public interface IProviderMigrator
{
Task Migrate(Guid providerId);
Task<ProviderMigrationResult> GetResult(Guid providerId);
}

View File

@ -0,0 +1,107 @@
using System.Text.Json;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Migration.Models;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Billing.Migration.Services.Implementations;
public class MigrationTrackerDistributedCache(
[FromKeyedServices("persistent")]
IDistributedCache distributedCache) : IMigrationTrackerCache
{
public async Task StartTracker(Provider provider) =>
await SetAsync(new ProviderMigrationTracker
{
ProviderId = provider.Id,
ProviderName = provider.Name
});
public async Task SetOrganizationIds(Guid providerId, IEnumerable<Guid> organizationIds)
{
var tracker = await GetAsync(providerId);
tracker.OrganizationIds = organizationIds.ToList();
await SetAsync(tracker);
}
public Task<ProviderMigrationTracker> GetTracker(Guid providerId) => GetAsync(providerId);
public async Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status)
{
var tracker = await GetAsync(providerId);
tracker.Progress = status;
await SetAsync(tracker);
}
public async Task StartTracker(Guid providerId, Organization organization) =>
await SetAsync(new ClientMigrationTracker
{
ProviderId = providerId,
OrganizationId = organization.Id,
OrganizationName = organization.Name
});
public Task<ClientMigrationTracker> GetTracker(Guid providerId, Guid organizationId) =>
GetAsync(providerId, organizationId);
public async Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status)
{
var tracker = await GetAsync(providerId, organizationId);
tracker.Progress = status;
await SetAsync(tracker);
}
private static string GetProviderCacheKey(Guid providerId) => $"provider_{providerId}_migration";
private static string GetClientCacheKey(Guid providerId, Guid clientId) =>
$"provider_{providerId}_client_{clientId}_migration";
private async Task<ProviderMigrationTracker> GetAsync(Guid providerId)
{
var cacheKey = GetProviderCacheKey(providerId);
var json = await distributedCache.GetStringAsync(cacheKey);
return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize<ProviderMigrationTracker>(json);
}
private async Task<ClientMigrationTracker> GetAsync(Guid providerId, Guid organizationId)
{
var cacheKey = GetClientCacheKey(providerId, organizationId);
var json = await distributedCache.GetStringAsync(cacheKey);
return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize<ClientMigrationTracker>(json);
}
private async Task SetAsync(ProviderMigrationTracker tracker)
{
var cacheKey = GetProviderCacheKey(tracker.ProviderId);
var json = JsonSerializer.Serialize(tracker);
await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions
{
SlidingExpiration = TimeSpan.FromMinutes(30)
});
}
private async Task SetAsync(ClientMigrationTracker tracker)
{
var cacheKey = GetClientCacheKey(tracker.ProviderId, tracker.OrganizationId);
var json = JsonSerializer.Serialize(tracker);
await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions
{
SlidingExpiration = TimeSpan.FromMinutes(30)
});
}
}

View File

@ -0,0 +1,326 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Migration.Models;
using Bit.Core.Billing.Repositories;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using Stripe;
using Plan = Bit.Core.Models.StaticStore.Plan;
namespace Bit.Core.Billing.Migration.Services.Implementations;
public class OrganizationMigrator(
IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository,
ILogger<OrganizationMigrator> logger,
IMigrationTrackerCache migrationTrackerCache,
IOrganizationRepository organizationRepository,
IStripeAdapter stripeAdapter) : IOrganizationMigrator
{
private const string _cancellationComment = "Cancelled as part of provider migration to Consolidated Billing";
public async Task Migrate(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Starting migration for organization ({OrganizationID})", organization.Id);
await migrationTrackerCache.StartTracker(providerId, organization);
await CreateMigrationRecordAsync(providerId, organization);
await CancelSubscriptionAsync(providerId, organization);
await UpdateOrganizationAsync(providerId, organization);
}
#region Steps
private async Task CreateMigrationRecordAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Creating ClientOrganizationMigrationRecord for organization ({OrganizationID})", organization.Id);
var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id);
if (migrationRecord != null)
{
logger.LogInformation(
"CB: ClientOrganizationMigrationRecord already exists for organization ({OrganizationID}), deleting record",
organization.Id);
await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord);
}
await clientOrganizationMigrationRecordRepository.CreateAsync(new ClientOrganizationMigrationRecord
{
OrganizationId = organization.Id,
ProviderId = providerId,
PlanType = organization.PlanType,
Seats = organization.Seats ?? 0,
MaxStorageGb = organization.MaxStorageGb,
GatewayCustomerId = organization.GatewayCustomerId!,
GatewaySubscriptionId = organization.GatewaySubscriptionId!,
ExpirationDate = organization.ExpirationDate,
MaxAutoscaleSeats = organization.MaxAutoscaleSeats,
Status = organization.Status
});
logger.LogInformation("CB: Created migration record for organization ({OrganizationID})", organization.Id);
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.MigrationRecordCreated);
}
private async Task CancelSubscriptionAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Cancelling subscription for organization ({OrganizationID})", organization.Id);
var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId);
if (subscription is
{
Status:
StripeConstants.SubscriptionStatus.Active or
StripeConstants.SubscriptionStatus.PastDue or
StripeConstants.SubscriptionStatus.Trialing
})
{
await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId,
new SubscriptionUpdateOptions { CancelAtPeriodEnd = false });
subscription = await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId,
new SubscriptionCancelOptions
{
CancellationDetails = new SubscriptionCancellationDetailsOptions
{
Comment = _cancellationComment
},
InvoiceNow = true,
Prorate = true,
Expand = ["latest_invoice", "test_clock"]
});
logger.LogInformation("CB: Cancelled subscription for organization ({OrganizationID})", organization.Id);
var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow;
var trialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now;
if (!trialing && subscription is { Status: StripeConstants.SubscriptionStatus.Canceled, CancellationDetails.Comment: _cancellationComment })
{
var latestInvoice = subscription.LatestInvoice;
if (latestInvoice.Status == "draft")
{
await stripeAdapter.InvoiceFinalizeInvoiceAsync(latestInvoice.Id,
new InvoiceFinalizeOptions { AutoAdvance = true });
logger.LogInformation("CB: Finalized prorated invoice for organization ({OrganizationID})", organization.Id);
}
}
}
else
{
logger.LogInformation(
"CB: Did not need to cancel subscription for organization ({OrganizationID}) as it was inactive",
organization.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.SubscriptionEnded);
}
private async Task UpdateOrganizationAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Bringing organization ({OrganizationID}) under provider management",
organization.Id);
var plan = StaticStore.GetPlan(organization.Plan.Contains("Teams") ? PlanType.TeamsMonthly : PlanType.EnterpriseMonthly);
ResetOrganizationPlan(organization, plan);
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.GatewaySubscriptionId = null;
organization.ExpirationDate = null;
organization.MaxAutoscaleSeats = null;
organization.Status = OrganizationStatusType.Managed;
await organizationRepository.ReplaceAsync(organization);
logger.LogInformation("CB: Brought organization ({OrganizationID}) under provider management",
organization.Id);
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.Completed);
}
#endregion
#region Reverse
private async Task RemoveMigrationRecordAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Removing migration record for organization ({OrganizationID})", organization.Id);
var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id);
if (migrationRecord != null)
{
await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord);
logger.LogInformation(
"CB: Removed migration record for organization ({OrganizationID})",
organization.Id);
}
else
{
logger.LogInformation("CB: Did not remove migration record for organization ({OrganizationID}) as it does not exist", organization.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, ClientMigrationProgress.Reversed);
}
private async Task RecreateSubscriptionAsync(Guid providerId, Organization organization)
{
logger.LogInformation("CB: Recreating subscription for organization ({OrganizationID})", organization.Id);
if (!string.IsNullOrEmpty(organization.GatewaySubscriptionId))
{
if (string.IsNullOrEmpty(organization.GatewayCustomerId))
{
logger.LogError(
"CB: Cannot recreate subscription for organization ({OrganizationID}) as it does not have a Stripe customer",
organization.Id);
throw new Exception();
}
var customer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId,
new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] });
var collectionMethod =
customer.DefaultSource != null ||
customer.InvoiceSettings?.DefaultPaymentMethod != null ||
customer.Metadata.ContainsKey(Utilities.BraintreeCustomerIdKey)
? StripeConstants.CollectionMethod.ChargeAutomatically
: StripeConstants.CollectionMethod.SendInvoice;
var plan = StaticStore.GetPlan(organization.PlanType);
var items = new List<SubscriptionItemOptions>
{
new ()
{
Price = plan.PasswordManager.StripeSeatPlanId,
Quantity = organization.Seats
}
};
if (organization.MaxStorageGb.HasValue && plan.PasswordManager.BaseStorageGb.HasValue && organization.MaxStorageGb.Value > plan.PasswordManager.BaseStorageGb.Value)
{
var additionalStorage = organization.MaxStorageGb.Value - plan.PasswordManager.BaseStorageGb.Value;
items.Add(new SubscriptionItemOptions
{
Price = plan.PasswordManager.StripeStoragePlanId,
Quantity = additionalStorage
});
}
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
Customer = customer.Id,
CollectionMethod = collectionMethod,
DaysUntilDue = collectionMethod == StripeConstants.CollectionMethod.SendInvoice ? 30 : null,
Items = items,
Metadata = new Dictionary<string, string>
{
[organization.GatewayIdField()] = organization.Id.ToString()
},
OffSession = true,
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
TrialPeriodDays = plan.TrialPeriodDays
};
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
organization.GatewaySubscriptionId = subscription.Id;
await organizationRepository.ReplaceAsync(organization);
logger.LogInformation("CB: Recreated subscription for organization ({OrganizationID})", organization.Id);
}
else
{
logger.LogInformation(
"CB: Did not recreate subscription for organization ({OrganizationID}) as it already exists",
organization.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.RecreatedSubscription);
}
private async Task ReverseOrganizationUpdateAsync(Guid providerId, Organization organization)
{
var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id);
if (migrationRecord == null)
{
logger.LogError(
"CB: Cannot reverse migration for organization ({OrganizationID}) as it does not have a migration record",
organization.Id);
throw new Exception();
}
var plan = StaticStore.GetPlan(migrationRecord.PlanType);
ResetOrganizationPlan(organization, plan);
organization.MaxStorageGb = migrationRecord.MaxStorageGb;
organization.ExpirationDate = migrationRecord.ExpirationDate;
organization.MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats;
organization.Status = migrationRecord.Status;
await organizationRepository.ReplaceAsync(organization);
logger.LogInformation("CB: Reversed organization ({OrganizationID}) updates",
organization.Id);
await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id,
ClientMigrationProgress.ResetOrganization);
}
#endregion
#region Shared
private static void ResetOrganizationPlan(Organization organization, Plan plan)
{
organization.Plan = plan.Name;
organization.PlanType = plan.Type;
organization.MaxCollections = plan.PasswordManager.MaxCollections;
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.UsePolicies = plan.HasPolicies;
organization.UseSso = plan.HasSso;
organization.UseGroups = plan.HasGroups;
organization.UseEvents = plan.HasEvents;
organization.UseDirectory = plan.HasDirectory;
organization.UseTotp = plan.HasTotp;
organization.Use2fa = plan.Has2fa;
organization.UseApi = plan.HasApi;
organization.UseResetPassword = plan.HasResetPassword;
organization.SelfHost = plan.HasSelfHost;
organization.UsersGetPremium = plan.UsersGetPremium;
organization.UseCustomPermissions = plan.HasCustomPermissions;
organization.UseScim = plan.HasScim;
organization.UseKeyConnector = plan.HasKeyConnector;
}
#endregion
}

View File

@ -0,0 +1,385 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Migration.Models;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using Stripe;
namespace Bit.Core.Billing.Migration.Services.Implementations;
public class ProviderMigrator(
IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository,
IOrganizationMigrator organizationMigrator,
ILogger<ProviderMigrator> logger,
IMigrationTrackerCache migrationTrackerCache,
IOrganizationRepository organizationRepository,
IPaymentService paymentService,
IProviderBillingService providerBillingService,
IProviderOrganizationRepository providerOrganizationRepository,
IProviderRepository providerRepository,
IProviderPlanRepository providerPlanRepository,
IStripeAdapter stripeAdapter) : IProviderMigrator
{
public async Task Migrate(Guid providerId)
{
var provider = await GetProviderAsync(providerId);
if (provider == null)
{
return;
}
logger.LogInformation("CB: Starting migration for provider ({ProviderID})", providerId);
await migrationTrackerCache.StartTracker(provider);
await MigrateClientsAsync(providerId);
await ConfigureTeamsPlanAsync(providerId);
await ConfigureEnterprisePlanAsync(providerId);
await SetupCustomerAsync(provider);
await SetupSubscriptionAsync(provider);
await ApplyCreditAsync(provider);
await UpdateProviderAsync(provider);
}
public async Task<ProviderMigrationResult> GetResult(Guid providerId)
{
var providerTracker = await migrationTrackerCache.GetTracker(providerId);
if (providerTracker == null)
{
return null;
}
var clientTrackers = await Task.WhenAll(providerTracker.OrganizationIds.Select(organizationId =>
migrationTrackerCache.GetTracker(providerId, organizationId)));
var migrationRecordLookup = new Dictionary<Guid, ClientOrganizationMigrationRecord>();
foreach (var clientTracker in clientTrackers)
{
var migrationRecord =
await clientOrganizationMigrationRecordRepository.GetByOrganizationId(clientTracker.OrganizationId);
migrationRecordLookup.Add(clientTracker.OrganizationId, migrationRecord);
}
return new ProviderMigrationResult
{
ProviderId = providerTracker.ProviderId,
ProviderName = providerTracker.ProviderName,
Result = providerTracker.Progress.ToString(),
Clients = clientTrackers.Select(tracker =>
{
var foundMigrationRecord = migrationRecordLookup.TryGetValue(tracker.OrganizationId, out var migrationRecord);
return new ClientMigrationResult
{
OrganizationId = tracker.OrganizationId,
OrganizationName = tracker.OrganizationName,
Result = tracker.Progress.ToString(),
PreviousState = foundMigrationRecord ? new ClientPreviousState(migrationRecord) : null
};
}).ToList(),
};
}
#region Steps
private async Task MigrateClientsAsync(Guid providerId)
{
logger.LogInformation("CB: Migrating clients for provider ({ProviderID})", providerId);
var organizations = await GetEnabledClientsAsync(providerId);
var organizationIds = organizations.Select(organization => organization.Id);
await migrationTrackerCache.SetOrganizationIds(providerId, organizationIds);
foreach (var organization in organizations)
{
var tracker = await migrationTrackerCache.GetTracker(providerId, organization.Id);
if (tracker is not { Progress: ClientMigrationProgress.Completed })
{
await organizationMigrator.Migrate(providerId, organization);
}
}
logger.LogInformation("CB: Migrated clients for provider ({ProviderID})", providerId);
await migrationTrackerCache.UpdateTrackingStatus(providerId,
ProviderMigrationProgress.ClientsMigrated);
}
private async Task ConfigureTeamsPlanAsync(Guid providerId)
{
logger.LogInformation("CB: Configuring Teams plan for provider ({ProviderID})", providerId);
var organizations = await GetEnabledClientsAsync(providerId);
var teamsSeats = organizations
.Where(IsTeams)
.Sum(client => client.Seats) ?? 0;
var teamsProviderPlan = (await providerPlanRepository.GetByProviderId(providerId))
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly);
if (teamsProviderPlan == null)
{
await providerPlanRepository.CreateAsync(new ProviderPlan
{
ProviderId = providerId,
PlanType = PlanType.TeamsMonthly,
SeatMinimum = teamsSeats,
PurchasedSeats = 0,
AllocatedSeats = teamsSeats
});
logger.LogInformation("CB: Created Teams plan for provider ({ProviderID}) with a seat minimum of {Seats}",
providerId, teamsSeats);
}
else
{
logger.LogInformation("CB: Teams plan already exists for provider ({ProviderID}), updating seat minimum", providerId);
teamsProviderPlan.SeatMinimum = teamsSeats;
teamsProviderPlan.AllocatedSeats = teamsSeats;
await providerPlanRepository.ReplaceAsync(teamsProviderPlan);
logger.LogInformation("CB: Updated Teams plan for provider ({ProviderID}) to seat minimum of {Seats}",
providerId, teamsProviderPlan.SeatMinimum);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.TeamsPlanConfigured);
}
private async Task ConfigureEnterprisePlanAsync(Guid providerId)
{
logger.LogInformation("CB: Configuring Enterprise plan for provider ({ProviderID})", providerId);
var organizations = await GetEnabledClientsAsync(providerId);
var enterpriseSeats = organizations
.Where(IsEnterprise)
.Sum(client => client.Seats) ?? 0;
var enterpriseProviderPlan = (await providerPlanRepository.GetByProviderId(providerId))
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly);
if (enterpriseProviderPlan == null)
{
await providerPlanRepository.CreateAsync(new ProviderPlan
{
ProviderId = providerId,
PlanType = PlanType.EnterpriseMonthly,
SeatMinimum = enterpriseSeats,
PurchasedSeats = 0,
AllocatedSeats = enterpriseSeats
});
logger.LogInformation("CB: Created Enterprise plan for provider ({ProviderID}) with a seat minimum of {Seats}",
providerId, enterpriseSeats);
}
else
{
logger.LogInformation("CB: Enterprise plan already exists for provider ({ProviderID}), updating seat minimum", providerId);
enterpriseProviderPlan.SeatMinimum = enterpriseSeats;
enterpriseProviderPlan.AllocatedSeats = enterpriseSeats;
await providerPlanRepository.ReplaceAsync(enterpriseProviderPlan);
logger.LogInformation("CB: Updated Enterprise plan for provider ({ProviderID}) to seat minimum of {Seats}",
providerId, enterpriseProviderPlan.SeatMinimum);
}
await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.EnterprisePlanConfigured);
}
private async Task SetupCustomerAsync(Provider provider)
{
if (string.IsNullOrEmpty(provider.GatewayCustomerId))
{
var organizations = await GetEnabledClientsAsync(provider.Id);
var sampleOrganization = organizations.FirstOrDefault(organization => !string.IsNullOrEmpty(organization.GatewayCustomerId));
if (sampleOrganization == null)
{
logger.LogInformation(
"CB: Could not find sample organization for provider ({ProviderID}) that has a Stripe customer",
provider.Id);
return;
}
var taxInfo = await paymentService.GetTaxInfoAsync(sampleOrganization);
var customer = await providerBillingService.SetupCustomer(provider, taxInfo);
await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Coupon = StripeConstants.CouponIDs.MSPDiscount35
});
provider.GatewayCustomerId = customer.Id;
await providerRepository.ReplaceAsync(provider);
logger.LogInformation("CB: Setup Stripe customer for provider ({ProviderID})", provider.Id);
}
else
{
logger.LogInformation("CB: Stripe customer already exists for provider ({ProviderID})", provider.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CustomerSetup);
}
private async Task SetupSubscriptionAsync(Provider provider)
{
if (string.IsNullOrEmpty(provider.GatewaySubscriptionId))
{
if (!string.IsNullOrEmpty(provider.GatewayCustomerId))
{
var subscription = await providerBillingService.SetupSubscription(provider);
provider.GatewaySubscriptionId = subscription.Id;
await providerRepository.ReplaceAsync(provider);
logger.LogInformation("CB: Setup Stripe subscription for provider ({ProviderID})", provider.Id);
}
else
{
logger.LogInformation(
"CB: Could not set up Stripe subscription for provider ({ProviderID}) with no Stripe customer",
provider.Id);
return;
}
}
else
{
logger.LogInformation("CB: Stripe subscription already exists for provider ({ProviderID})", provider.Id);
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
var enterpriseSeatMinimum = providerPlans
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly)?
.SeatMinimum ?? 0;
var teamsSeatMinimum = providerPlans
.FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly)?
.SeatMinimum ?? 0;
await providerBillingService.UpdateSeatMinimums(provider, enterpriseSeatMinimum, teamsSeatMinimum);
logger.LogInformation(
"CB: Updated Stripe subscription for provider ({ProviderID}) with current seat minimums", provider.Id);
}
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.SubscriptionSetup);
}
private async Task ApplyCreditAsync(Provider provider)
{
var organizations = await GetEnabledClientsAsync(provider.Id);
var organizationCustomers =
await Task.WhenAll(organizations.Select(organization => stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId)));
var organizationCancellationCredit = organizationCustomers.Sum(customer => customer.Balance);
var legacyOrganizations = organizations.Where(organization =>
organization.PlanType is
PlanType.EnterpriseAnnually2020 or
PlanType.EnterpriseMonthly2020 or
PlanType.TeamsAnnually2020 or
PlanType.TeamsMonthly2020);
var legacyOrganizationCredit = legacyOrganizations.Sum(organization => organization.Seats ?? 0);
await stripeAdapter.CustomerUpdateAsync(provider.GatewayCustomerId, new CustomerUpdateOptions
{
Balance = organizationCancellationCredit + legacyOrganizationCredit
});
logger.LogInformation("CB: Applied {Credit} credit to provider ({ProviderID})", organizationCancellationCredit, provider.Id);
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CreditApplied);
}
private async Task UpdateProviderAsync(Provider provider)
{
provider.Status = ProviderStatusType.Billable;
await providerRepository.ReplaceAsync(provider);
logger.LogInformation("CB: Completed migration for provider ({ProviderID})", provider.Id);
await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.Completed);
}
#endregion
#region Utilities
private async Task<List<Organization>> GetEnabledClientsAsync(Guid providerId)
{
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId);
return (await Task.WhenAll(providerOrganizations.Select(providerOrganization =>
organizationRepository.GetByIdAsync(providerOrganization.OrganizationId))))
.Where(organization => organization.Enabled)
.ToList();
}
private async Task<Provider> GetProviderAsync(Guid providerId)
{
var provider = await providerRepository.GetByIdAsync(providerId);
if (provider == null)
{
logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it does not exist", providerId);
return null;
}
if (provider.Type != ProviderType.Msp)
{
logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not an MSP", providerId);
return null;
}
if (provider.Status == ProviderStatusType.Created)
{
return provider;
}
logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not in the 'Created' state", providerId);
return null;
}
private static bool IsEnterprise(Organization organization) => organization.Plan.Contains("Enterprise");
private static bool IsTeams(Organization organization) => organization.Plan.Contains("Teams");
#endregion
}

View File

@ -0,0 +1,10 @@
using Bit.Core.Billing.Entities;
using Bit.Core.Repositories;
namespace Bit.Core.Billing.Repositories;
public interface IClientOrganizationMigrationRecordRepository : IRepository<ClientOrganizationMigrationRecord, Guid>
{
Task<ClientOrganizationMigrationRecord> GetByOrganizationId(Guid organizationId);
Task<ICollection<ClientOrganizationMigrationRecord>> GetByProviderId(Guid providerId);
}