1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-05 05:00:19 -05:00

[AC-1753] Automatically assign provider's pricing to new organizations (#3513)

* Initial commit

* resolve pr comment

* adding some unit test

* Resolve pr comments

* Adding some unit test

* Resolve pr comment

* changes to find the bug

* revert back changes on admin

* Fix the failing Test

* fix the bug
This commit is contained in:
cyprain-okeke 2023-12-20 22:54:45 +01:00 committed by GitHub
parent 5785905103
commit 75cae907e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 206 additions and 14 deletions

View File

@ -1,4 +1,7 @@
using Bit.Core.AdminConsole.Entities.Provider; using System.ComponentModel.DataAnnotations;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Provider;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
@ -33,13 +36,14 @@ public class ProviderService : IProviderService
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly IOrganizationService _organizationService; private readonly IOrganizationService _organizationService;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly IStripeAdapter _stripeAdapter;
public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository,
IUserService userService, IOrganizationService organizationService, IMailService mailService, IUserService userService, IOrganizationService organizationService, IMailService mailService,
IDataProtectionProvider dataProtectionProvider, IEventService eventService, IDataProtectionProvider dataProtectionProvider, IEventService eventService,
IOrganizationRepository organizationRepository, GlobalSettings globalSettings, IOrganizationRepository organizationRepository, GlobalSettings globalSettings,
ICurrentContext currentContext) ICurrentContext currentContext, IStripeAdapter stripeAdapter)
{ {
_providerRepository = providerRepository; _providerRepository = providerRepository;
_providerUserRepository = providerUserRepository; _providerUserRepository = providerUserRepository;
@ -53,6 +57,7 @@ public class ProviderService : IProviderService
_globalSettings = globalSettings; _globalSettings = globalSettings;
_dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); _dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector");
_currentContext = currentContext; _currentContext = currentContext;
_stripeAdapter = stripeAdapter;
} }
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key)
@ -369,6 +374,7 @@ public class ProviderService : IProviderService
Key = key, Key = key,
}; };
await ApplyProviderPriceRateAsync(organizationId, providerId);
await _providerOrganizationRepository.CreateAsync(providerOrganization); await _providerOrganizationRepository.CreateAsync(providerOrganization);
await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added); await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added);
} }
@ -381,18 +387,110 @@ public class ProviderService : IProviderService
throw new BadRequestException("Provider must be of type Reseller in order to assign Organizations to it."); throw new BadRequestException("Provider must be of type Reseller in order to assign Organizations to it.");
} }
var existingProviderOrganizationsCount = await _providerOrganizationRepository.GetCountByOrganizationIdsAsync(organizationIds); var orgIdsList = organizationIds.ToList();
var existingProviderOrganizationsCount = await _providerOrganizationRepository.GetCountByOrganizationIdsAsync(orgIdsList);
if (existingProviderOrganizationsCount > 0) if (existingProviderOrganizationsCount > 0)
{ {
throw new BadRequestException("Organizations must not be assigned to any Provider."); throw new BadRequestException("Organizations must not be assigned to any Provider.");
} }
var providerOrganizationsToInsert = organizationIds.Select(orgId => new ProviderOrganization { ProviderId = providerId, OrganizationId = orgId }); var providerOrganizationsToInsert = orgIdsList.Select(orgId => new ProviderOrganization { ProviderId = providerId, OrganizationId = orgId });
var insertedProviderOrganizations = await _providerOrganizationRepository.CreateManyAsync(providerOrganizationsToInsert); var insertedProviderOrganizations = await _providerOrganizationRepository.CreateManyAsync(providerOrganizationsToInsert);
await _eventService.LogProviderOrganizationEventsAsync(insertedProviderOrganizations.Select(ipo => (ipo, EventType.ProviderOrganization_Added, (DateTime?)null))); await _eventService.LogProviderOrganizationEventsAsync(insertedProviderOrganizations.Select(ipo => (ipo, EventType.ProviderOrganization_Added, (DateTime?)null)));
} }
private async Task ApplyProviderPriceRateAsync(Guid organizationId, Guid providerId)
{
var provider = await _providerRepository.GetByIdAsync(providerId);
// if a provider was created before Nov 6, 2023.If true, the organization plan assigned to that provider is updated to a 2020 plan.
if (provider.CreationDate >= Constants.ProviderCreatedPriorNov62023)
{
return;
}
var organization = await _organizationRepository.GetByIdAsync(organizationId);
var subscriptionItem = await GetSubscriptionItemAsync(organization.GatewaySubscriptionId, GetStripeSeatPlanId(organization.PlanType));
var extractedPlanType = PlanTypeMappings(organization);
if (subscriptionItem != null)
{
await UpdateSubscriptionAsync(subscriptionItem, GetStripeSeatPlanId(extractedPlanType), organization);
}
await _organizationRepository.UpsertAsync(organization);
}
private async Task<Stripe.SubscriptionItem> GetSubscriptionItemAsync(string subscriptionId, string oldPlanId)
{
var subscriptionDetails = await _stripeAdapter.SubscriptionGetAsync(subscriptionId);
return subscriptionDetails.Items.Data.FirstOrDefault(item => item.Price.Id == oldPlanId);
}
private static string GetStripeSeatPlanId(PlanType planType)
{
return StaticStore.GetPlan(planType).PasswordManager.StripeSeatPlanId;
}
private async Task UpdateSubscriptionAsync(Stripe.SubscriptionItem subscriptionItem, string extractedPlanType, Organization organization)
{
try
{
if (subscriptionItem.Price.Id != extractedPlanType)
{
await _stripeAdapter.SubscriptionUpdateAsync(subscriptionItem.Subscription,
new Stripe.SubscriptionUpdateOptions
{
Items = new List<Stripe.SubscriptionItemOptions>
{
new()
{
Id = subscriptionItem.Id,
Price = extractedPlanType,
Quantity = organization.Seats.Value,
},
}
});
}
}
catch (Exception)
{
throw new Exception("Unable to update existing plan on stripe");
}
}
private static PlanType PlanTypeMappings(Organization organization)
{
var planTypeMappings = new Dictionary<PlanType, string>
{
{ PlanType.EnterpriseAnnually2020, GetEnumDisplayName(PlanType.EnterpriseAnnually2020) },
{ PlanType.EnterpriseMonthly2020, GetEnumDisplayName(PlanType.EnterpriseMonthly2020) },
{ PlanType.TeamsMonthly2020, GetEnumDisplayName(PlanType.TeamsMonthly2020) },
{ PlanType.TeamsAnnually2020, GetEnumDisplayName(PlanType.TeamsAnnually2020) }
};
foreach (var mapping in planTypeMappings)
{
if (mapping.Value.IndexOf(organization.Plan, StringComparison.Ordinal) != -1)
{
organization.PlanType = mapping.Key;
organization.Plan = mapping.Value;
return organization.PlanType;
}
}
throw new ArgumentException("Invalid PlanType selected");
}
private static string GetEnumDisplayName(Enum value)
{
var fieldInfo = value.GetType().GetField(value.ToString());
var displayAttribute = (DisplayAttribute)Attribute.GetCustomAttribute(fieldInfo!, typeof(DisplayAttribute));
return displayAttribute?.Name ?? value.ToString();
}
public async Task<ProviderOrganization> CreateOrganizationAsync(Guid providerId, public async Task<ProviderOrganization> CreateOrganizationAsync(Guid providerId,
OrganizationSignup organizationSignup, string clientOwnerEmail, User user) OrganizationSignup organizationSignup, string clientOwnerEmail, User user)
{ {

View File

@ -18,6 +18,7 @@ using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.DataProtection;
using NSubstitute; using NSubstitute;
using NSubstitute.ReturnsExtensions; using NSubstitute.ReturnsExtensions;
using Stripe;
using Xunit; using Xunit;
using Provider = Bit.Core.AdminConsole.Entities.Provider.Provider; using Provider = Bit.Core.AdminConsole.Entities.Provider.Provider;
using ProviderUser = Bit.Core.AdminConsole.Entities.Provider.ProviderUser; using ProviderUser = Bit.Core.AdminConsole.Entities.Provider.ProviderUser;
@ -598,4 +599,98 @@ public class ProviderServiceTests
await sutProvider.GetDependency<IEventService>().Received() await sutProvider.GetDependency<IEventService>().Received()
.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed);
} }
[Theory, BitAutoData]
public async Task AddOrganization_CreateAfterNov162023_PlanTypeDoesNotUpdated(Provider provider, Organization organization, string key,
SutProvider<ProviderService> sutProvider)
{
provider.Type = ProviderType.Msp;
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
var expectedPlanType = PlanType.EnterpriseAnnually;
organization.PlanType = PlanType.EnterpriseAnnually;
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, key);
await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default);
await sutProvider.GetDependency<IEventService>()
.Received().LogProviderOrganizationEventAsync(Arg.Any<ProviderOrganization>(),
EventType.ProviderOrganization_Added);
Assert.Equal(organization.PlanType, expectedPlanType);
}
[Theory, BitAutoData]
public async Task AddOrganization_CreateBeforeNov162023_PlanTypeUpdated(Provider provider, Organization organization, string key,
SutProvider<ProviderService> sutProvider)
{
var newCreationDate = DateTime.UtcNow.AddMonths(-3);
BackdateProviderCreationDate(provider, newCreationDate);
provider.Type = ProviderType.Msp;
organization.PlanType = PlanType.EnterpriseAnnually;
organization.Plan = "Enterprise (Annually)";
var expectedPlanType = PlanType.EnterpriseAnnually2020;
var expectedPlanId = "2020-enterprise-org-seat-annually";
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
providerOrganizationRepository.GetByOrganizationId(organization.Id).ReturnsNull();
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
var subscriptionItem = GetSubscription(organization.GatewaySubscriptionId);
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(organization.GatewaySubscriptionId)
.Returns(GetSubscription(organization.GatewaySubscriptionId));
await sutProvider.GetDependency<IStripeAdapter>().SubscriptionUpdateAsync(
organization.GatewaySubscriptionId, SubscriptionUpdateRequest(expectedPlanId, subscriptionItem));
await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, key);
await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default);
await sutProvider.GetDependency<IEventService>()
.Received().LogProviderOrganizationEventAsync(Arg.Any<ProviderOrganization>(),
EventType.ProviderOrganization_Added);
Assert.Equal(organization.PlanType, expectedPlanType);
}
private static SubscriptionUpdateOptions SubscriptionUpdateRequest(string expectedPlanId, Subscription subscriptionItem) =>
new()
{
Items = new List<Stripe.SubscriptionItemOptions>
{
new() { Id = subscriptionItem.Id, Price = expectedPlanId },
}
};
private static Subscription GetSubscription(string subscriptionId) =>
new()
{
Id = subscriptionId,
Items = new StripeList<SubscriptionItem>
{
Data = new List<SubscriptionItem>
{
new()
{
Id = "sub_item_123",
Price = new Price()
{
Id = "2023-enterprise-org-seat-annually"
}
}
}
}
};
private static void BackdateProviderCreationDate(Provider provider, DateTime newCreationDate)
{
// Set the CreationDate to the desired value
provider.GetType().GetProperty("CreationDate")?.SetValue(provider, newCreationDate, null);
}
} }

View File

@ -21,6 +21,7 @@ public class ProviderResponseModel : ResponseModel
BusinessCountry = provider.BusinessCountry; BusinessCountry = provider.BusinessCountry;
BusinessTaxNumber = provider.BusinessTaxNumber; BusinessTaxNumber = provider.BusinessTaxNumber;
BillingEmail = provider.BillingEmail; BillingEmail = provider.BillingEmail;
CreationDate = provider.CreationDate;
} }
public Guid Id { get; set; } public Guid Id { get; set; }
@ -32,4 +33,5 @@ public class ProviderResponseModel : ResponseModel
public string BusinessCountry { get; set; } public string BusinessCountry { get; set; }
public string BusinessTaxNumber { get; set; } public string BusinessTaxNumber { get; set; }
public string BillingEmail { get; set; } public string BillingEmail { get; set; }
public DateTime CreationDate { get; set; }
} }

View File

@ -1890,11 +1890,6 @@ public class OrganizationService : IOrganizationService
public void ValidatePasswordManagerPlan(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade) public void ValidatePasswordManagerPlan(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade)
{ {
if (plan is not { LegacyYear: null })
{
throw new BadRequestException("Invalid Password Manager plan selected.");
}
ValidatePlan(plan, upgrade.AdditionalSeats, "Password Manager"); ValidatePlan(plan, upgrade.AdditionalSeats, "Password Manager");
if (plan.PasswordManager.BaseSeats + upgrade.AdditionalSeats <= 0) if (plan.PasswordManager.BaseSeats + upgrade.AdditionalSeats <= 0)
@ -2409,12 +2404,8 @@ public class OrganizationService : IOrganizationService
public async Task CreatePendingOrganization(Organization organization, string ownerEmail, ClaimsPrincipal user, IUserService userService, bool salesAssistedTrialStarted) public async Task CreatePendingOrganization(Organization organization, string ownerEmail, ClaimsPrincipal user, IUserService userService, bool salesAssistedTrialStarted)
{ {
var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType);
if (plan is not { LegacyYear: null })
{
throw new BadRequestException("Invalid plan selected.");
}
if (plan.Disabled) if (plan!.Disabled)
{ {
throw new BadRequestException("Plan not found."); throw new BadRequestException("Plan not found.");
} }

View File

@ -29,6 +29,12 @@ public static class Constants
/// Used by IdentityServer to identify our own provider. /// Used by IdentityServer to identify our own provider.
/// </summary> /// </summary>
public const string IdentityProvider = "bitwarden"; public const string IdentityProvider = "bitwarden";
/// <summary>
/// Date identifier used in ProviderService to determine if a provider was created before Nov 6, 2023.
/// If true, the organization plan assigned to that provider is updated to a 2020 plan.
/// </summary>
public static readonly DateTime ProviderCreatedPriorNov62023 = new DateTime(2023, 11, 6);
} }
public static class AuthConstants public static class AuthConstants