1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-05 05:00:19 -05:00

Merge branch 'main' into jmccannon/ac/pm-16811-scim-invite-optimization

# Conflicts:
#	src/Api/Startup.cs
#	src/Core/Services/IPaymentService.cs
#	src/Core/Services/Implementations/StripePaymentService.cs
This commit is contained in:
jrmccannon 2025-04-03 07:57:04 -05:00
commit dda7906d83
No known key found for this signature in database
GPG Key ID: CF03F3DB01CE96A6
80 changed files with 4465 additions and 738 deletions

View File

@ -3,7 +3,7 @@
<PropertyGroup> <PropertyGroup>
<TargetFramework>net8.0</TargetFramework> <TargetFramework>net8.0</TargetFramework>
<Version>2025.3.3</Version> <Version>2025.4.0</Version>
<RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace> <RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace>
<ImplicitUsings>enable</ImplicitUsings> <ImplicitUsings>enable</ImplicitUsings>

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.Providers.Interfaces; using Bit.Core.AdminConsole.Providers.Interfaces;
@ -7,10 +8,12 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Microsoft.Extensions.DependencyInjection;
using Stripe; using Stripe;
namespace Bit.Commercial.Core.AdminConsole.Providers; namespace Bit.Commercial.Core.AdminConsole.Providers;
@ -28,6 +31,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
private readonly ISubscriberService _subscriberService; private readonly ISubscriberService _subscriberService;
private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly IAutomaticTaxStrategy _automaticTaxStrategy;
public RemoveOrganizationFromProviderCommand( public RemoveOrganizationFromProviderCommand(
IEventService eventService, IEventService eventService,
@ -40,7 +44,8 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
IProviderBillingService providerBillingService, IProviderBillingService providerBillingService,
ISubscriberService subscriberService, ISubscriberService subscriberService,
IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery,
IPricingClient pricingClient) IPricingClient pricingClient,
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
{ {
_eventService = eventService; _eventService = eventService;
_mailService = mailService; _mailService = mailService;
@ -53,6 +58,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
_subscriberService = subscriberService; _subscriberService = subscriberService;
_hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery;
_pricingClient = pricingClient; _pricingClient = pricingClient;
_automaticTaxStrategy = automaticTaxStrategy;
} }
public async Task RemoveOrganizationFromProvider( public async Task RemoveOrganizationFromProvider(
@ -107,10 +113,11 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
organization.IsValidClient() && organization.IsValidClient() &&
!string.IsNullOrEmpty(organization.GatewayCustomerId)) !string.IsNullOrEmpty(organization.GatewayCustomerId))
{ {
await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions var customer = await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions
{ {
Description = string.Empty, Description = string.Empty,
Email = organization.BillingEmail Email = organization.BillingEmail,
Expand = ["tax", "tax_ids"]
}); });
var plan = await _pricingClient.GetPlanOrThrow(organization.PlanType); var plan = await _pricingClient.GetPlanOrThrow(organization.PlanType);
@ -120,7 +127,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
Customer = organization.GatewayCustomerId, Customer = organization.GatewayCustomerId,
CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, CollectionMethod = StripeConstants.CollectionMethod.SendInvoice,
DaysUntilDue = 30, DaysUntilDue = 30,
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
Metadata = new Dictionary<string, string> Metadata = new Dictionary<string, string>
{ {
{ "organizationId", organization.Id.ToString() } { "organizationId", organization.Id.ToString() }
@ -130,6 +136,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }] Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }]
}; };
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
_automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
}
var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
organization.GatewaySubscriptionId = subscription.Id; organization.GatewaySubscriptionId = subscription.Id;

View File

@ -14,6 +14,7 @@ using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
@ -22,6 +23,7 @@ using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using CsvHelper; using CsvHelper;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Stripe; using Stripe;
@ -29,10 +31,10 @@ namespace Bit.Commercial.Core.Billing;
public class ProviderBillingService( public class ProviderBillingService(
IEventService eventService, IEventService eventService,
IFeatureService featureService,
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
ILogger<ProviderBillingService> logger, ILogger<ProviderBillingService> logger,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
IPaymentService paymentService,
IPricingClient pricingClient, IPricingClient pricingClient,
IProviderInvoiceItemRepository providerInvoiceItemRepository, IProviderInvoiceItemRepository providerInvoiceItemRepository,
IProviderOrganizationRepository providerOrganizationRepository, IProviderOrganizationRepository providerOrganizationRepository,
@ -40,7 +42,9 @@ public class ProviderBillingService(
IProviderUserRepository providerUserRepository, IProviderUserRepository providerUserRepository,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService,
ITaxService taxService) : IProviderBillingService ITaxService taxService,
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
: IProviderBillingService
{ {
[RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)] [RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)]
public async Task AddExistingOrganization( public async Task AddExistingOrganization(
@ -143,36 +147,29 @@ public class ProviderBillingService(
public async Task ChangePlan(ChangeProviderPlanCommand command) public async Task ChangePlan(ChangeProviderPlanCommand command)
{ {
var plan = await providerPlanRepository.GetByIdAsync(command.ProviderPlanId); var (provider, providerPlanId, newPlanType) = command;
if (plan == null) var providerPlan = await providerPlanRepository.GetByIdAsync(providerPlanId);
if (providerPlan == null)
{ {
throw new BadRequestException("Provider plan not found."); throw new BadRequestException("Provider plan not found.");
} }
if (plan.PlanType == command.NewPlan) if (providerPlan.PlanType == newPlanType)
{ {
return; return;
} }
var oldPlanConfiguration = await pricingClient.GetPlanOrThrow(plan.PlanType); var subscription = await subscriberService.GetSubscriptionOrThrow(provider);
var newPlanConfiguration = await pricingClient.GetPlanOrThrow(command.NewPlan);
plan.PlanType = command.NewPlan; var oldPriceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType);
await providerPlanRepository.ReplaceAsync(plan); var newPriceId = ProviderPriceAdapter.GetPriceId(provider, subscription, newPlanType);
Subscription subscription; providerPlan.PlanType = newPlanType;
try await providerPlanRepository.ReplaceAsync(providerPlan);
{
subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, plan.ProviderId);
}
catch (InvalidOperationException)
{
throw new ConflictException("Subscription not found.");
}
var oldSubscriptionItem = subscription.Items.SingleOrDefault(x => var oldSubscriptionItem = subscription.Items.SingleOrDefault(x => x.Price.Id == oldPriceId);
x.Price.Id == oldPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId);
var updateOptions = new SubscriptionUpdateOptions var updateOptions = new SubscriptionUpdateOptions
{ {
@ -180,7 +177,7 @@ public class ProviderBillingService(
[ [
new SubscriptionItemOptions new SubscriptionItemOptions
{ {
Price = newPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId, Price = newPriceId,
Quantity = oldSubscriptionItem!.Quantity Quantity = oldSubscriptionItem!.Quantity
}, },
new SubscriptionItemOptions new SubscriptionItemOptions
@ -191,12 +188,14 @@ public class ProviderBillingService(
] ]
}; };
await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId, updateOptions); await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, updateOptions);
// Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId) // Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId)
// 1. Retrieve PlanType and PlanName for ProviderPlan // 1. Retrieve PlanType and PlanName for ProviderPlan
// 2. Assign PlanType & PlanName to Organization // 2. Assign PlanType & PlanName to Organization
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(plan.ProviderId); var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId);
var newPlan = await pricingClient.GetPlanOrThrow(newPlanType);
foreach (var providerOrganization in providerOrganizations) foreach (var providerOrganization in providerOrganizations)
{ {
@ -205,8 +204,8 @@ public class ProviderBillingService(
{ {
throw new ConflictException($"Organization '{providerOrganization.Id}' not found."); throw new ConflictException($"Organization '{providerOrganization.Id}' not found.");
} }
organization.PlanType = command.NewPlan; organization.PlanType = newPlanType;
organization.Plan = newPlanConfiguration.Name; organization.Plan = newPlan.Name;
await organizationRepository.ReplaceAsync(organization); await organizationRepository.ReplaceAsync(organization);
} }
} }
@ -400,7 +399,7 @@ public class ProviderBillingService(
var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment; var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment;
var update = CurrySeatScalingUpdate( var scaleQuantityTo = CurrySeatScalingUpdate(
provider, provider,
providerPlan, providerPlan,
newlyAssignedSeatTotal); newlyAssignedSeatTotal);
@ -423,9 +422,7 @@ public class ProviderBillingService(
else if (currentlyAssignedSeatTotal <= seatMinimum && else if (currentlyAssignedSeatTotal <= seatMinimum &&
newlyAssignedSeatTotal > seatMinimum) newlyAssignedSeatTotal > seatMinimum)
{ {
await update( await scaleQuantityTo(newlyAssignedSeatTotal);
seatMinimum,
newlyAssignedSeatTotal);
} }
/* /*
* Above the limit => Above the limit: * Above the limit => Above the limit:
@ -434,9 +431,7 @@ public class ProviderBillingService(
else if (currentlyAssignedSeatTotal > seatMinimum && else if (currentlyAssignedSeatTotal > seatMinimum &&
newlyAssignedSeatTotal > seatMinimum) newlyAssignedSeatTotal > seatMinimum)
{ {
await update( await scaleQuantityTo(newlyAssignedSeatTotal);
currentlyAssignedSeatTotal,
newlyAssignedSeatTotal);
} }
/* /*
* Above the limit => Below the limit: * Above the limit => Below the limit:
@ -445,9 +440,7 @@ public class ProviderBillingService(
else if (currentlyAssignedSeatTotal > seatMinimum && else if (currentlyAssignedSeatTotal > seatMinimum &&
newlyAssignedSeatTotal <= seatMinimum) newlyAssignedSeatTotal <= seatMinimum)
{ {
await update( await scaleQuantityTo(seatMinimum);
currentlyAssignedSeatTotal,
seatMinimum);
} }
} }
@ -557,7 +550,8 @@ public class ProviderBillingService(
{ {
ArgumentNullException.ThrowIfNull(provider); ArgumentNullException.ThrowIfNull(provider);
var customer = await subscriberService.GetCustomerOrThrow(provider); var customerGetOptions = new CustomerGetOptions { Expand = ["tax", "tax_ids"] };
var customer = await subscriberService.GetCustomerOrThrow(provider, customerGetOptions);
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
@ -580,19 +574,17 @@ public class ProviderBillingService(
throw new BillingException(); throw new BillingException();
} }
var priceId = ProviderPriceAdapter.GetActivePriceId(provider, providerPlan.PlanType);
subscriptionItemOptionsList.Add(new SubscriptionItemOptions subscriptionItemOptionsList.Add(new SubscriptionItemOptions
{ {
Price = plan.PasswordManager.StripeProviderPortalSeatPlanId, Price = priceId,
Quantity = providerPlan.SeatMinimum Quantity = providerPlan.SeatMinimum
}); });
} }
var subscriptionCreateOptions = new SubscriptionCreateOptions var subscriptionCreateOptions = new SubscriptionCreateOptions
{ {
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, CollectionMethod = StripeConstants.CollectionMethod.SendInvoice,
Customer = customer.Id, Customer = customer.Id,
DaysUntilDue = 30, DaysUntilDue = 30,
@ -605,6 +597,15 @@ public class ProviderBillingService(
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations
}; };
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
}
try try
{ {
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
@ -643,43 +644,37 @@ public class ProviderBillingService(
public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command) public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command)
{ {
if (command.Configuration.Any(x => x.SeatsMinimum < 0)) var (provider, updatedPlanConfigurations) = command;
if (updatedPlanConfigurations.Any(x => x.SeatsMinimum < 0))
{ {
throw new BadRequestException("Provider seat minimums must be at least 0."); throw new BadRequestException("Provider seat minimums must be at least 0.");
} }
Subscription subscription; var subscription = await subscriberService.GetSubscriptionOrThrow(provider);
try
{
subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, command.Id);
}
catch (InvalidOperationException)
{
throw new ConflictException("Subscription not found.");
}
var subscriptionItemOptionsList = new List<SubscriptionItemOptions>(); var subscriptionItemOptionsList = new List<SubscriptionItemOptions>();
var providerPlans = await providerPlanRepository.GetByProviderId(command.Id); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
foreach (var newPlanConfiguration in command.Configuration) foreach (var updatedPlanConfiguration in updatedPlanConfigurations)
{ {
var (updatedPlanType, updatedSeatMinimum) = updatedPlanConfiguration;
var providerPlan = var providerPlan =
providerPlans.Single(providerPlan => providerPlan.PlanType == newPlanConfiguration.Plan); providerPlans.Single(providerPlan => providerPlan.PlanType == updatedPlanType);
if (providerPlan.SeatMinimum != newPlanConfiguration.SeatsMinimum) if (providerPlan.SeatMinimum != updatedSeatMinimum)
{ {
var newPlan = await pricingClient.GetPlanOrThrow(newPlanConfiguration.Plan); var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, updatedPlanType);
var priceId = newPlan.PasswordManager.StripeProviderPortalSeatPlanId;
var subscriptionItem = subscription.Items.First(item => item.Price.Id == priceId); var subscriptionItem = subscription.Items.First(item => item.Price.Id == priceId);
if (providerPlan.PurchasedSeats == 0) if (providerPlan.PurchasedSeats == 0)
{ {
if (providerPlan.AllocatedSeats > newPlanConfiguration.SeatsMinimum) if (providerPlan.AllocatedSeats > updatedSeatMinimum)
{ {
providerPlan.PurchasedSeats = providerPlan.AllocatedSeats - newPlanConfiguration.SeatsMinimum; providerPlan.PurchasedSeats = providerPlan.AllocatedSeats - updatedSeatMinimum;
subscriptionItemOptionsList.Add(new SubscriptionItemOptions subscriptionItemOptionsList.Add(new SubscriptionItemOptions
{ {
@ -694,7 +689,7 @@ public class ProviderBillingService(
{ {
Id = subscriptionItem.Id, Id = subscriptionItem.Id,
Price = priceId, Price = priceId,
Quantity = newPlanConfiguration.SeatsMinimum Quantity = updatedSeatMinimum
}); });
} }
} }
@ -702,9 +697,9 @@ public class ProviderBillingService(
{ {
var totalSeats = providerPlan.SeatMinimum + providerPlan.PurchasedSeats; var totalSeats = providerPlan.SeatMinimum + providerPlan.PurchasedSeats;
if (newPlanConfiguration.SeatsMinimum <= totalSeats) if (updatedSeatMinimum <= totalSeats)
{ {
providerPlan.PurchasedSeats = totalSeats - newPlanConfiguration.SeatsMinimum; providerPlan.PurchasedSeats = totalSeats - updatedSeatMinimum;
} }
else else
{ {
@ -713,12 +708,12 @@ public class ProviderBillingService(
{ {
Id = subscriptionItem.Id, Id = subscriptionItem.Id,
Price = priceId, Price = priceId,
Quantity = newPlanConfiguration.SeatsMinimum Quantity = updatedSeatMinimum
}); });
} }
} }
providerPlan.SeatMinimum = newPlanConfiguration.SeatsMinimum; providerPlan.SeatMinimum = updatedSeatMinimum;
await providerPlanRepository.ReplaceAsync(providerPlan); await providerPlanRepository.ReplaceAsync(providerPlan);
} }
@ -726,23 +721,33 @@ public class ProviderBillingService(
if (subscriptionItemOptionsList.Count > 0) if (subscriptionItemOptionsList.Count > 0)
{ {
await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId, await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
new SubscriptionUpdateOptions { Items = subscriptionItemOptionsList }); new SubscriptionUpdateOptions { Items = subscriptionItemOptionsList });
} }
} }
private Func<int, int, Task> CurrySeatScalingUpdate( private Func<int, Task> CurrySeatScalingUpdate(
Provider provider, Provider provider,
ProviderPlan providerPlan, ProviderPlan providerPlan,
int newlyAssignedSeats) => async (currentlySubscribedSeats, newlySubscribedSeats) => int newlyAssignedSeats) => async newlySubscribedSeats =>
{ {
var plan = await pricingClient.GetPlanOrThrow(providerPlan.PlanType); var subscription = await subscriberService.GetSubscriptionOrThrow(provider);
await paymentService.AdjustSeats( var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType);
provider,
plan, var item = subscription.Items.First(item => item.Price.Id == priceId);
currentlySubscribedSeats,
newlySubscribedSeats); await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions
{
Items = [
new SubscriptionItemOptions
{
Id = item.Id,
Price = priceId,
Quantity = newlySubscribedSeats
}
]
});
var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum
? newlySubscribedSeats - providerPlan.SeatMinimum ? newlySubscribedSeats - providerPlan.SeatMinimum

View File

@ -0,0 +1,133 @@
// ReSharper disable SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault
#nullable enable
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing;
using Bit.Core.Billing.Enums;
using Stripe;
namespace Bit.Commercial.Core.Billing;
public static class ProviderPriceAdapter
{
public static class MSP
{
public static class Active
{
public const string Enterprise = "provider-portal-enterprise-monthly-2025";
public const string Teams = "provider-portal-teams-monthly-2025";
}
public static class Legacy
{
public const string Enterprise = "password-manager-provider-portal-enterprise-monthly-2024";
public const string Teams = "password-manager-provider-portal-teams-monthly-2024";
public static readonly List<string> List = [Enterprise, Teams];
}
}
public static class BusinessUnit
{
public static class Active
{
public const string Annually = "business-unit-portal-enterprise-annually-2025";
public const string Monthly = "business-unit-portal-enterprise-monthly-2025";
}
public static class Legacy
{
public const string Annually = "password-manager-provider-portal-enterprise-annually-2024";
public const string Monthly = "password-manager-provider-portal-enterprise-monthly-2024";
public static readonly List<string> List = [Annually, Monthly];
}
}
/// <summary>
/// Uses the <paramref name="provider"/>'s <see cref="Provider.Type"/> and <paramref name="subscription"/> to determine
/// whether the <paramref name="provider"/> is on active or legacy pricing and then returns a Stripe price ID for the provided
/// <paramref name="planType"/> based on that determination.
/// </summary>
/// <param name="provider">The provider to get the Stripe price ID for.</param>
/// <param name="subscription">The provider's subscription.</param>
/// <param name="planType">The plan type correlating to the desired Stripe price ID.</param>
/// <returns>A Stripe <see cref="Stripe.Price"/> ID.</returns>
/// <exception cref="BillingException">Thrown when the provider's type is not <see cref="ProviderType.Msp"/> or <see cref="ProviderType.MultiOrganizationEnterprise"/>.</exception>
/// <exception cref="BillingException">Thrown when the provided <see cref="planType"/> does not relate to a Stripe price ID.</exception>
public static string GetPriceId(
Provider provider,
Subscription subscription,
PlanType planType)
{
var priceIds = subscription.Items.Select(item => item.Price.Id);
var invalidPlanType =
new BillingException(message: $"PlanType {planType} does not have an associated provider price in Stripe");
return provider.Type switch
{
ProviderType.Msp => MSP.Legacy.List.Intersect(priceIds).Any()
? planType switch
{
PlanType.TeamsMonthly => MSP.Legacy.Teams,
PlanType.EnterpriseMonthly => MSP.Legacy.Enterprise,
_ => throw invalidPlanType
}
: planType switch
{
PlanType.TeamsMonthly => MSP.Active.Teams,
PlanType.EnterpriseMonthly => MSP.Active.Enterprise,
_ => throw invalidPlanType
},
ProviderType.MultiOrganizationEnterprise => BusinessUnit.Legacy.List.Intersect(priceIds).Any()
? planType switch
{
PlanType.EnterpriseAnnually => BusinessUnit.Legacy.Annually,
PlanType.EnterpriseMonthly => BusinessUnit.Legacy.Monthly,
_ => throw invalidPlanType
}
: planType switch
{
PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually,
PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly,
_ => throw invalidPlanType
},
_ => throw new BillingException(
$"ProviderType {provider.Type} does not have any associated provider price IDs")
};
}
/// <summary>
/// Uses the <paramref name="provider"/>'s <see cref="Provider.Type"/> to return the active Stripe price ID for the provided
/// <paramref name="planType"/>.
/// </summary>
/// <param name="provider">The provider to get the Stripe price ID for.</param>
/// <param name="planType">The plan type correlating to the desired Stripe price ID.</param>
/// <returns>A Stripe <see cref="Stripe.Price"/> ID.</returns>
/// <exception cref="BillingException">Thrown when the provider's type is not <see cref="ProviderType.Msp"/> or <see cref="ProviderType.MultiOrganizationEnterprise"/>.</exception>
/// <exception cref="BillingException">Thrown when the provided <see cref="planType"/> does not relate to a Stripe price ID.</exception>
public static string GetActivePriceId(
Provider provider,
PlanType planType)
{
var invalidPlanType =
new BillingException(message: $"PlanType {planType} does not have an associated provider price in Stripe");
return provider.Type switch
{
ProviderType.Msp => planType switch
{
PlanType.TeamsMonthly => MSP.Active.Teams,
PlanType.EnterpriseMonthly => MSP.Active.Enterprise,
_ => throw invalidPlanType
},
ProviderType.MultiOrganizationEnterprise => planType switch
{
PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually,
PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly,
_ => throw invalidPlanType
},
_ => throw new BillingException(
$"ProviderType {provider.Type} does not have any associated provider price IDs")
};
}
}

View File

@ -228,6 +228,26 @@ public class RemoveOrganizationFromProviderCommandTests
Id = "subscription_id" Id = "subscription_id"
}); });
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == organization.GatewayCustomerId &&
options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice &&
options.DaysUntilDue == 30 &&
options.Metadata["organizationId"] == organization.Id.ToString() &&
options.OffSession == true &&
options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId &&
options.Items.First().Quantity == organization.Seats)
, Arg.Any<Customer>()))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options => await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>

View File

@ -4,6 +4,7 @@ using Bit.Commercial.Core.Billing;
using Bit.Commercial.Core.Billing.Models; using Bit.Commercial.Core.Billing.Models;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
@ -115,6 +116,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider) SutProvider<ProviderBillingService> sutProvider)
{ {
// Arrange // Arrange
provider.Type = ProviderType.MultiOrganizationEnterprise;
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>(); var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
var existingPlan = new ProviderPlan var existingPlan = new ProviderPlan
{ {
@ -132,10 +135,7 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(existingPlan.PlanType) sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(existingPlan.PlanType)
.Returns(StaticStore.GetPlan(existingPlan.PlanType)); .Returns(StaticStore.GetPlan(existingPlan.PlanType));
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider)
stripeAdapter.ProviderSubscriptionGetAsync(
Arg.Is(provider.GatewaySubscriptionId),
Arg.Is(provider.Id))
.Returns(new Subscription .Returns(new Subscription
{ {
Id = provider.GatewaySubscriptionId, Id = provider.GatewaySubscriptionId,
@ -158,7 +158,7 @@ public class ProviderBillingServiceTests
}); });
var command = var command =
new ChangeProviderPlanCommand(providerPlanId, PlanType.EnterpriseMonthly, provider.GatewaySubscriptionId); new ChangeProviderPlanCommand(provider, providerPlanId, PlanType.EnterpriseMonthly);
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(command.NewPlan) sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(command.NewPlan)
.Returns(StaticStore.GetPlan(command.NewPlan)); .Returns(StaticStore.GetPlan(command.NewPlan));
@ -170,6 +170,8 @@ public class ProviderBillingServiceTests
await providerPlanRepository.Received(1) await providerPlanRepository.Received(1)
.ReplaceAsync(Arg.Is<ProviderPlan>(p => p.PlanType == PlanType.EnterpriseMonthly)); .ReplaceAsync(Arg.Is<ProviderPlan>(p => p.PlanType == PlanType.EnterpriseMonthly));
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
await stripeAdapter.Received(1) await stripeAdapter.Received(1)
.SubscriptionUpdateAsync( .SubscriptionUpdateAsync(
Arg.Is(provider.GatewaySubscriptionId), Arg.Is(provider.GatewaySubscriptionId),
@ -405,6 +407,23 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans); sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } },
new SubscriptionItem
{
Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise }
}
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
// 50 seats currently assigned with a seat minimum of 100 // 50 seats currently assigned with a seat minimum of 100
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
@ -427,11 +446,9 @@ public class ProviderBillingServiceTests
await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10);
// 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum
await sutProvider.GetDependency<IPaymentService>().DidNotReceiveWithAnyArgs().AdjustSeats( await sutProvider.GetDependency<IStripeAdapter>().DidNotReceiveWithAnyArgs().SubscriptionUpdateAsync(
Arg.Any<Provider>(), Arg.Any<string>(),
Arg.Any<Bit.Core.Models.StaticStore.Plan>(), Arg.Any<SubscriptionUpdateOptions>());
Arg.Any<int>(),
Arg.Any<int>());
await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>( await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>(
pPlan => pPlan.AllocatedSeats == 60)); pPlan => pPlan.AllocatedSeats == 60));
@ -474,6 +491,23 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans); sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } },
new SubscriptionItem
{
Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise }
}
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
// 95 seats currently assigned with a seat minimum of 100 // 95 seats currently assigned with a seat minimum of 100
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
@ -496,11 +530,12 @@ public class ProviderBillingServiceTests
await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10);
// 95 current + 10 seat scale = 105 seats, 5 above the minimum // 95 current + 10 seat scale = 105 seats, 5 above the minimum
await sutProvider.GetDependency<IPaymentService>().Received(1).AdjustSeats( await sutProvider.GetDependency<IStripeAdapter>().Received(1).SubscriptionUpdateAsync(
provider, provider.GatewaySubscriptionId,
StaticStore.GetPlan(providerPlan.PlanType), Arg.Is<SubscriptionUpdateOptions>(
providerPlan.SeatMinimum!.Value, options =>
105); options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams &&
options.Items.First().Quantity == 105));
// 105 total seats - 100 minimum = 5 purchased seats // 105 total seats - 100 minimum = 5 purchased seats
await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>( await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>(
@ -544,6 +579,23 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans); sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } },
new SubscriptionItem
{
Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise }
}
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
// 110 seats currently assigned with a seat minimum of 100 // 110 seats currently assigned with a seat minimum of 100
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
@ -566,11 +618,12 @@ public class ProviderBillingServiceTests
await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10);
// 110 current + 10 seat scale up = 120 seats // 110 current + 10 seat scale up = 120 seats
await sutProvider.GetDependency<IPaymentService>().Received(1).AdjustSeats( await sutProvider.GetDependency<IStripeAdapter>().Received(1).SubscriptionUpdateAsync(
provider, provider.GatewaySubscriptionId,
StaticStore.GetPlan(providerPlan.PlanType), Arg.Is<SubscriptionUpdateOptions>(
110, options =>
120); options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams &&
options.Items.First().Quantity == 120));
// 120 total seats - 100 seat minimum = 20 purchased seats // 120 total seats - 100 seat minimum = 20 purchased seats
await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>( await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>(
@ -614,6 +667,23 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans); sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } },
new SubscriptionItem
{
Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise }
}
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
// 110 seats currently assigned with a seat minimum of 100 // 110 seats currently assigned with a seat minimum of 100
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
@ -636,11 +706,12 @@ public class ProviderBillingServiceTests
await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, -30); await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, -30);
// 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum.
await sutProvider.GetDependency<IPaymentService>().Received(1).AdjustSeats( await sutProvider.GetDependency<IStripeAdapter>().Received(1).SubscriptionUpdateAsync(
provider, provider.GatewaySubscriptionId,
StaticStore.GetPlan(providerPlan.PlanType), Arg.Is<SubscriptionUpdateOptions>(
110, options =>
providerPlan.SeatMinimum!.Value); options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams &&
options.Items.First().Quantity == providerPlan.SeatMinimum!.Value));
// Being below the seat minimum means no purchased seats. // Being below the seat minimum means no purchased seats.
await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>( await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>(
@ -924,11 +995,15 @@ public class ProviderBillingServiceTests
{ {
provider.GatewaySubscriptionId = null; provider.GatewaySubscriptionId = null;
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider).Returns(new Customer sutProvider.GetDependency<ISubscriberService>()
{ .GetCustomerOrThrow(
Id = "customer_id", provider,
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")))
}); .Returns(new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
});
var providerPlans = new List<ProviderPlan> var providerPlans = new List<ProviderPlan>
{ {
@ -973,13 +1048,18 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
Provider provider) Provider provider)
{ {
provider.Type = ProviderType.Msp;
provider.GatewaySubscriptionId = null; provider.GatewaySubscriptionId = null;
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider).Returns(new Customer var customer = new Customer
{ {
Id = "customer_id", Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}); };
sutProvider.GetDependency<ISubscriberService>()
.GetCustomerOrThrow(
provider,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer);
var providerPlans = new List<ProviderPlan> var providerPlans = new List<ProviderPlan>
{ {
@ -1012,11 +1092,21 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id) sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id)
.Returns(providerPlans); .Returns(providerPlans);
var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly);
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>( sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub => sub =>
sub.AutomaticTax.Enabled == true && sub.AutomaticTax.Enabled == true &&
@ -1024,9 +1114,9 @@ public class ProviderBillingServiceTests
sub.Customer == "customer_id" && sub.Customer == "customer_id" &&
sub.DaysUntilDue == 30 && sub.DaysUntilDue == 30 &&
sub.Items.Count == 2 && sub.Items.Count == 2 &&
sub.Items.ElementAt(0).Price == teamsPlan.PasswordManager.StripeProviderPortalSeatPlanId && sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams &&
sub.Items.ElementAt(0).Quantity == 100 && sub.Items.ElementAt(0).Quantity == 100 &&
sub.Items.ElementAt(1).Price == enterprisePlan.PasswordManager.StripeProviderPortalSeatPlanId && sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise &&
sub.Items.ElementAt(1).Quantity == 100 && sub.Items.ElementAt(1).Quantity == 100 &&
sub.Metadata["providerId"] == provider.Id.ToString() && sub.Metadata["providerId"] == provider.Id.ToString() &&
sub.OffSession == true && sub.OffSession == true &&
@ -1048,8 +1138,7 @@ public class ProviderBillingServiceTests
{ {
// Arrange // Arrange
var command = new UpdateProviderSeatMinimumsCommand( var command = new UpdateProviderSeatMinimumsCommand(
provider.Id, provider,
provider.GatewaySubscriptionId,
[ [
(PlanType.TeamsMonthly, -10), (PlanType.TeamsMonthly, -10),
(PlanType.EnterpriseMonthly, 50) (PlanType.EnterpriseMonthly, 50)
@ -1068,6 +1157,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider) SutProvider<ProviderBillingService> sutProvider)
{ {
// Arrange // Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>(); var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1097,9 +1188,7 @@ public class ProviderBillingServiceTests
} }
}; };
stripeAdapter.ProviderSubscriptionGetAsync( sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
provider.GatewaySubscriptionId,
provider.Id).Returns(subscription);
var providerPlans = new List<ProviderPlan> var providerPlans = new List<ProviderPlan>
{ {
@ -1116,8 +1205,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand( var command = new UpdateProviderSeatMinimumsCommand(
provider.Id, provider,
provider.GatewaySubscriptionId,
[ [
(PlanType.EnterpriseMonthly, 30), (PlanType.EnterpriseMonthly, 30),
(PlanType.TeamsMonthly, 20) (PlanType.TeamsMonthly, 20)
@ -1149,6 +1237,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider) SutProvider<ProviderBillingService> sutProvider)
{ {
// Arrange // Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>(); var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1178,7 +1268,7 @@ public class ProviderBillingServiceTests
} }
}; };
stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
var providerPlans = new List<ProviderPlan> var providerPlans = new List<ProviderPlan>
{ {
@ -1195,8 +1285,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand( var command = new UpdateProviderSeatMinimumsCommand(
provider.Id, provider,
provider.GatewaySubscriptionId,
[ [
(PlanType.EnterpriseMonthly, 70), (PlanType.EnterpriseMonthly, 70),
(PlanType.TeamsMonthly, 50) (PlanType.TeamsMonthly, 50)
@ -1228,6 +1317,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider) SutProvider<ProviderBillingService> sutProvider)
{ {
// Arrange // Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>(); var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1257,7 +1348,7 @@ public class ProviderBillingServiceTests
} }
}; };
stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
var providerPlans = new List<ProviderPlan> var providerPlans = new List<ProviderPlan>
{ {
@ -1274,8 +1365,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand( var command = new UpdateProviderSeatMinimumsCommand(
provider.Id, provider,
provider.GatewaySubscriptionId,
[ [
(PlanType.EnterpriseMonthly, 60), (PlanType.EnterpriseMonthly, 60),
(PlanType.TeamsMonthly, 60) (PlanType.TeamsMonthly, 60)
@ -1301,6 +1391,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider) SutProvider<ProviderBillingService> sutProvider)
{ {
// Arrange // Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>(); var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1330,7 +1422,7 @@ public class ProviderBillingServiceTests
} }
}; };
stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
var providerPlans = new List<ProviderPlan> var providerPlans = new List<ProviderPlan>
{ {
@ -1347,8 +1439,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand( var command = new UpdateProviderSeatMinimumsCommand(
provider.Id, provider,
provider.GatewaySubscriptionId,
[ [
(PlanType.EnterpriseMonthly, 80), (PlanType.EnterpriseMonthly, 80),
(PlanType.TeamsMonthly, 80) (PlanType.TeamsMonthly, 80)
@ -1380,6 +1471,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider) SutProvider<ProviderBillingService> sutProvider)
{ {
// Arrange // Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>(); var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1409,7 +1502,7 @@ public class ProviderBillingServiceTests
} }
}; };
stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription); sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
var providerPlans = new List<ProviderPlan> var providerPlans = new List<ProviderPlan>
{ {
@ -1426,8 +1519,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand( var command = new UpdateProviderSeatMinimumsCommand(
provider.Id, provider,
provider.GatewaySubscriptionId,
[ [
(PlanType.EnterpriseMonthly, 70), (PlanType.EnterpriseMonthly, 70),
(PlanType.TeamsMonthly, 30) (PlanType.TeamsMonthly, 30)

View File

@ -0,0 +1,151 @@
using Bit.Commercial.Core.Billing;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing.Enums;
using Stripe;
using Xunit;
namespace Bit.Commercial.Core.Test.Billing;
public class ProviderPriceAdapterTests
{
[Theory]
[InlineData("password-manager-provider-portal-enterprise-monthly-2024", PlanType.EnterpriseMonthly)]
[InlineData("password-manager-provider-portal-teams-monthly-2024", PlanType.TeamsMonthly)]
public void GetPriceId_MSP_Legacy_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.Msp
};
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = priceId } }
]
}
};
var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("provider-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)]
[InlineData("provider-portal-teams-monthly-2025", PlanType.TeamsMonthly)]
public void GetPriceId_MSP_Active_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.Msp
};
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = priceId } }
]
}
};
var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("password-manager-provider-portal-enterprise-annually-2024", PlanType.EnterpriseAnnually)]
[InlineData("password-manager-provider-portal-enterprise-monthly-2024", PlanType.EnterpriseMonthly)]
public void GetPriceId_BusinessUnit_Legacy_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.MultiOrganizationEnterprise
};
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = priceId } }
]
}
};
var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("business-unit-portal-enterprise-annually-2025", PlanType.EnterpriseAnnually)]
[InlineData("business-unit-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)]
public void GetPriceId_BusinessUnit_Active_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.MultiOrganizationEnterprise
};
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = priceId } }
]
}
};
var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("provider-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)]
[InlineData("provider-portal-teams-monthly-2025", PlanType.TeamsMonthly)]
public void GetActivePriceId_MSP_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.Msp
};
var result = ProviderPriceAdapter.GetActivePriceId(provider, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("business-unit-portal-enterprise-annually-2025", PlanType.EnterpriseAnnually)]
[InlineData("business-unit-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)]
public void GetActivePriceId_BusinessUnit_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.MultiOrganizationEnterprise
};
var result = ProviderPriceAdapter.GetActivePriceId(provider, planType);
Assert.Equal(result, priceId);
}
}

View File

@ -300,8 +300,7 @@ public class ProvidersController : Controller
{ {
case ProviderType.Msp: case ProviderType.Msp:
var updateMspSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( var updateMspSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand(
provider.Id, provider,
provider.GatewaySubscriptionId,
[ [
(Plan: PlanType.TeamsMonthly, SeatsMinimum: model.TeamsMonthlySeatMinimum), (Plan: PlanType.TeamsMonthly, SeatsMinimum: model.TeamsMonthlySeatMinimum),
(Plan: PlanType.EnterpriseMonthly, SeatsMinimum: model.EnterpriseMonthlySeatMinimum) (Plan: PlanType.EnterpriseMonthly, SeatsMinimum: model.EnterpriseMonthlySeatMinimum)
@ -314,15 +313,14 @@ public class ProvidersController : Controller
// 1. Change the plan and take over any old values. // 1. Change the plan and take over any old values.
var changeMoePlanCommand = new ChangeProviderPlanCommand( var changeMoePlanCommand = new ChangeProviderPlanCommand(
provider,
existingMoePlan.Id, existingMoePlan.Id,
model.Plan!.Value, model.Plan!.Value);
provider.GatewaySubscriptionId);
await _providerBillingService.ChangePlan(changeMoePlanCommand); await _providerBillingService.ChangePlan(changeMoePlanCommand);
// 2. Update the seat minimums. // 2. Update the seat minimums.
var updateMoeSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( var updateMoeSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand(
provider.Id, provider,
provider.GatewaySubscriptionId,
[ [
(Plan: model.Plan!.Value, SeatsMinimum: model.EnterpriseMinimumSeats!.Value) (Plan: model.Plan!.Value, SeatsMinimum: model.EnterpriseMinimumSeats!.Value)
]); ]);

View File

@ -13,7 +13,17 @@ public static class PolicyDetailResponses
{ {
throw new ArgumentException($"'{nameof(policy)}' must be of type '{nameof(PolicyType.SingleOrg)}'.", nameof(policy)); throw new ArgumentException($"'{nameof(policy)}' must be of type '{nameof(PolicyType.SingleOrg)}'.", nameof(policy));
} }
return new PolicyDetailResponseModel(policy, await CanToggleState());
return new PolicyDetailResponseModel(policy, !await hasVerifiedDomainsQuery.HasVerifiedDomainsAsync(policy.OrganizationId)); async Task<bool> CanToggleState()
{
if (!await hasVerifiedDomainsQuery.HasVerifiedDomainsAsync(policy.OrganizationId))
{
return true;
}
return !policy.Enabled;
}
} }
} }

View File

@ -76,6 +76,13 @@ public class OrganizationSponsorshipsController : Controller
public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model)
{ {
var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId);
var freeFamiliesSponsorshipPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(sponsoringOrgId,
PolicyType.FreeFamiliesSponsorshipPolicy);
if (freeFamiliesSponsorshipPolicy?.Enabled == true)
{
throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator.");
}
var sponsorship = await _createSponsorshipCommand.CreateSponsorshipAsync( var sponsorship = await _createSponsorshipCommand.CreateSponsorshipAsync(
sponsoringOrg, sponsoringOrg,
@ -89,6 +96,14 @@ public class OrganizationSponsorshipsController : Controller
[SelfHosted(NotSelfHostedOnly = true)] [SelfHosted(NotSelfHostedOnly = true)]
public async Task ResendSponsorshipOffer(Guid sponsoringOrgId) public async Task ResendSponsorshipOffer(Guid sponsoringOrgId)
{ {
var freeFamiliesSponsorshipPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(sponsoringOrgId,
PolicyType.FreeFamiliesSponsorshipPolicy);
if (freeFamiliesSponsorshipPolicy?.Enabled == true)
{
throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator.");
}
var sponsoringOrgUser = await _organizationUserRepository var sponsoringOrgUser = await _organizationUserRepository
.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default); .GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default);
@ -135,6 +150,14 @@ public class OrganizationSponsorshipsController : Controller
throw new BadRequestException("Can only redeem sponsorship for an organization you own."); throw new BadRequestException("Can only redeem sponsorship for an organization you own.");
} }
var freeFamiliesSponsorshipPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(
model.SponsoredOrganizationId, PolicyType.FreeFamiliesSponsorshipPolicy);
if (freeFamiliesSponsorshipPolicy?.Enabled == true)
{
throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator.");
}
await _setUpSponsorshipCommand.SetUpSponsorshipAsync( await _setUpSponsorshipCommand.SetUpSponsorshipAsync(
sponsorship, sponsorship,
await _organizationRepository.GetByIdAsync(model.SponsoredOrganizationId)); await _organizationRepository.GetByIdAsync(model.SponsoredOrganizationId));

View File

@ -1,6 +1,5 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Request;
using Bit.Api.Auth.Models.Request.Accounts;
using Bit.Api.Models.Request; using Bit.Api.Models.Request;
using Bit.Api.Models.Response; using Bit.Api.Models.Response;
using Bit.Core.Auth.Models.Api.Request; using Bit.Core.Auth.Models.Api.Request;
@ -125,7 +124,7 @@ public class DevicesController : Controller
} }
[HttpPost("{identifier}/retrieve-keys")] [HttpPost("{identifier}/retrieve-keys")]
public async Task<ProtectedDeviceResponseModel> GetDeviceKeys(string identifier, [FromBody] SecretVerificationRequestModel model) public async Task<ProtectedDeviceResponseModel> GetDeviceKeys(string identifier)
{ {
var user = await _userService.GetUserByPrincipalAsync(User); var user = await _userService.GetUserByPrincipalAsync(User);
@ -134,14 +133,7 @@ public class DevicesController : Controller
throw new UnauthorizedAccessException(); throw new UnauthorizedAccessException();
} }
if (!await _userService.VerifySecretAsync(user, model.Secret))
{
await Task.Delay(2000);
throw new BadRequestException(string.Empty, "User verification failed.");
}
var device = await _deviceRepository.GetByIdentifierAsync(identifier, user.Id); var device = await _deviceRepository.GetByIdentifierAsync(identifier, user.Id);
if (device == null) if (device == null)
{ {
throw new NotFoundException(); throw new NotFoundException();

View File

@ -8,6 +8,7 @@ using Bit.Api.Tools.Models.Request;
using Bit.Api.Vault.Models.Request; using Bit.Api.Vault.Models.Request;
using Bit.Core; using Bit.Core;
using Bit.Core.Auth.Entities; using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
@ -43,6 +44,7 @@ public class AccountsKeyManagementController : Controller
_organizationUserValidator; _organizationUserValidator;
private readonly IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>> private readonly IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>>
_webauthnKeyValidator; _webauthnKeyValidator;
private readonly IRotationValidator<IEnumerable<OtherDeviceKeysUpdateRequestModel>, IEnumerable<Device>> _deviceValidator;
public AccountsKeyManagementController(IUserService userService, public AccountsKeyManagementController(IUserService userService,
IFeatureService featureService, IFeatureService featureService,
@ -57,7 +59,8 @@ public class AccountsKeyManagementController : Controller
emergencyAccessValidator, emergencyAccessValidator,
IRotationValidator<IEnumerable<ResetPasswordWithOrgIdRequestModel>, IReadOnlyList<OrganizationUser>> IRotationValidator<IEnumerable<ResetPasswordWithOrgIdRequestModel>, IReadOnlyList<OrganizationUser>>
organizationUserValidator, organizationUserValidator,
IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>> webAuthnKeyValidator) IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>> webAuthnKeyValidator,
IRotationValidator<IEnumerable<OtherDeviceKeysUpdateRequestModel>, IEnumerable<Device>> deviceValidator)
{ {
_userService = userService; _userService = userService;
_featureService = featureService; _featureService = featureService;
@ -71,6 +74,7 @@ public class AccountsKeyManagementController : Controller
_emergencyAccessValidator = emergencyAccessValidator; _emergencyAccessValidator = emergencyAccessValidator;
_organizationUserValidator = organizationUserValidator; _organizationUserValidator = organizationUserValidator;
_webauthnKeyValidator = webAuthnKeyValidator; _webauthnKeyValidator = webAuthnKeyValidator;
_deviceValidator = deviceValidator;
} }
[HttpPost("regenerate-keys")] [HttpPost("regenerate-keys")]
@ -109,6 +113,7 @@ public class AccountsKeyManagementController : Controller
EmergencyAccesses = await _emergencyAccessValidator.ValidateAsync(user, model.AccountUnlockData.EmergencyAccessUnlockData), EmergencyAccesses = await _emergencyAccessValidator.ValidateAsync(user, model.AccountUnlockData.EmergencyAccessUnlockData),
OrganizationUsers = await _organizationUserValidator.ValidateAsync(user, model.AccountUnlockData.OrganizationAccountRecoveryUnlockData), OrganizationUsers = await _organizationUserValidator.ValidateAsync(user, model.AccountUnlockData.OrganizationAccountRecoveryUnlockData),
WebAuthnKeys = await _webauthnKeyValidator.ValidateAsync(user, model.AccountUnlockData.PasskeyUnlockData), WebAuthnKeys = await _webauthnKeyValidator.ValidateAsync(user, model.AccountUnlockData.PasskeyUnlockData),
DeviceKeys = await _deviceValidator.ValidateAsync(user, model.AccountUnlockData.DeviceKeyUnlockData),
Ciphers = await _cipherValidator.ValidateAsync(user, model.AccountData.Ciphers), Ciphers = await _cipherValidator.ValidateAsync(user, model.AccountData.Ciphers),
Folders = await _folderValidator.ValidateAsync(user, model.AccountData.Folders), Folders = await _folderValidator.ValidateAsync(user, model.AccountData.Folders),

View File

@ -3,6 +3,7 @@ using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Request;
using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Auth.Models.Request.Accounts;
using Bit.Api.Auth.Models.Request.WebAuthn; using Bit.Api.Auth.Models.Request.WebAuthn;
using Bit.Core.Auth.Models.Api.Request;
namespace Bit.Api.KeyManagement.Models.Requests; namespace Bit.Api.KeyManagement.Models.Requests;
@ -13,4 +14,5 @@ public class UnlockDataRequestModel
public required IEnumerable<EmergencyAccessWithIdRequestModel> EmergencyAccessUnlockData { get; set; } public required IEnumerable<EmergencyAccessWithIdRequestModel> EmergencyAccessUnlockData { get; set; }
public required IEnumerable<ResetPasswordWithOrgIdRequestModel> OrganizationAccountRecoveryUnlockData { get; set; } public required IEnumerable<ResetPasswordWithOrgIdRequestModel> OrganizationAccountRecoveryUnlockData { get; set; }
public required IEnumerable<WebAuthnLoginRotateKeyRequestModel> PasskeyUnlockData { get; set; } public required IEnumerable<WebAuthnLoginRotateKeyRequestModel> PasskeyUnlockData { get; set; }
public required IEnumerable<OtherDeviceKeysUpdateRequestModel> DeviceKeyUnlockData { get; set; }
} }

View File

@ -0,0 +1,53 @@
using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Auth.Utilities;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
namespace Bit.Api.KeyManagement.Validators;
/// <summary>
/// Device implementation for <see cref="IRotationValidator{T,R}"/>
/// </summary>
public class DeviceRotationValidator : IRotationValidator<IEnumerable<OtherDeviceKeysUpdateRequestModel>, IEnumerable<Device>>
{
private readonly IDeviceRepository _deviceRepository;
/// <summary>
/// Instantiates a new <see cref="DeviceRotationValidator"/>
/// </summary>
/// <param name="deviceRepository">Retrieves all user <see cref="Device"/>s</param>
public DeviceRotationValidator(IDeviceRepository deviceRepository)
{
_deviceRepository = deviceRepository;
}
public async Task<IEnumerable<Device>> ValidateAsync(User user, IEnumerable<OtherDeviceKeysUpdateRequestModel> devices)
{
var result = new List<Device>();
var existingTrustedDevices = (await _deviceRepository.GetManyByUserIdAsync(user.Id)).Where(d => d.IsTrusted()).ToList();
if (existingTrustedDevices.Count == 0)
{
return result;
}
foreach (var existing in existingTrustedDevices)
{
var device = devices.FirstOrDefault(c => c.DeviceId == existing.Id);
if (device == null)
{
throw new BadRequestException("All existing trusted devices must be included in the rotation.");
}
if (device.EncryptedUserKey == null || device.EncryptedPublicKey == null)
{
throw new BadRequestException("Rotated encryption keys must be provided for all devices that are trusted.");
}
result.Add(device.ToDevice(existing));
}
return result;
}
}

View File

@ -22,6 +22,7 @@ public class NotificationResponseModel : ResponseModel
Title = notificationStatusDetails.Title; Title = notificationStatusDetails.Title;
Body = notificationStatusDetails.Body; Body = notificationStatusDetails.Body;
Date = notificationStatusDetails.RevisionDate; Date = notificationStatusDetails.RevisionDate;
TaskId = notificationStatusDetails.TaskId;
ReadDate = notificationStatusDetails.ReadDate; ReadDate = notificationStatusDetails.ReadDate;
DeletedDate = notificationStatusDetails.DeletedDate; DeletedDate = notificationStatusDetails.DeletedDate;
} }
@ -40,6 +41,8 @@ public class NotificationResponseModel : ResponseModel
public DateTime Date { get; set; } public DateTime Date { get; set; }
public Guid? TaskId { get; set; }
public DateTime? ReadDate { get; set; } public DateTime? ReadDate { get; set; }
public DateTime? DeletedDate { get; set; } public DateTime? DeletedDate { get; set; }

View File

@ -5,6 +5,7 @@ using Bit.Core.Settings;
using AspNetCoreRateLimit; using AspNetCoreRateLimit;
using Stripe; using Stripe;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Duende.IdentityModel;
using System.Globalization; using System.Globalization;
using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Request;
@ -30,8 +31,7 @@ using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.Identity.TokenProviders; using Bit.Core.Auth.Identity.TokenProviders;
using Bit.Core.Tools.ImportFeatures; using Bit.Core.Tools.ImportFeatures;
using Bit.Core.Tools.ReportFeatures; using Bit.Core.Tools.ReportFeatures;
using Duende.IdentityModel; using Bit.Core.Auth.Models.Api.Request;
#if !OSS #if !OSS
using Bit.Commercial.Core.SecretsManager; using Bit.Commercial.Core.SecretsManager;
@ -168,6 +168,9 @@ public class Startup
services services
.AddScoped<IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>>, .AddScoped<IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>>,
WebAuthnLoginKeyRotationValidator>(); WebAuthnLoginKeyRotationValidator>();
services
.AddScoped<IRotationValidator<IEnumerable<OtherDeviceKeysUpdateRequestModel>, IEnumerable<Device>>,
DeviceRotationValidator>();
// Services // Services
services.AddBaseServices(globalSettings); services.AddBaseServices(globalSettings);

View File

@ -16,6 +16,7 @@ using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Tools.Services; using Bit.Core.Tools.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Bit.Core.Vault.Authorization.Permissions;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Models.Data;
using Bit.Core.Vault.Queries; using Bit.Core.Vault.Queries;
@ -345,6 +346,77 @@ public class CiphersController : Controller
return await CanEditCiphersAsync(organizationId, cipherIds); return await CanEditCiphersAsync(organizationId, cipherIds);
} }
private async Task<bool> CanDeleteOrRestoreCipherAsAdminAsync(Guid organizationId, IEnumerable<Guid> cipherIds)
{
if (!_featureService.IsEnabled(FeatureFlagKeys.LimitItemDeletion))
{
return await CanEditCipherAsAdminAsync(organizationId, cipherIds);
}
var org = _currentContext.GetOrganization(organizationId);
// If we're not an "admin", we don't need to check the ciphers
if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true }))
{
// Are we a provider user? If so, we need to be sure we're not restricted
// Once the feature flag is removed, this check can be combined with the above
if (await _currentContext.ProviderUserForOrgAsync(organizationId))
{
// Provider is restricted from editing ciphers, so we're not an "admin"
if (_featureService.IsEnabled(FeatureFlagKeys.RestrictProviderAccess))
{
return false;
}
// Provider is unrestricted, so we're an "admin", don't return early
}
else
{
// Not a provider or admin
return false;
}
}
// If the user can edit all ciphers for the organization, just check they all belong to the org
if (await CanEditAllCiphersAsync(organizationId))
{
// TODO: This can likely be optimized to only query the requested ciphers and then checking they belong to the org
var orgCiphers = (await _cipherRepository.GetManyByOrganizationIdAsync(organizationId)).ToDictionary(c => c.Id);
// Ensure all requested ciphers are in orgCiphers
return cipherIds.All(c => orgCiphers.ContainsKey(c));
}
// The user cannot access any ciphers for the organization, we're done
if (!await CanAccessOrganizationCiphersAsync(organizationId))
{
return false;
}
var user = await _userService.GetUserByPrincipalAsync(User);
// Select all deletable ciphers for this user belonging to the organization
var deletableOrgCipherList = (await _cipherRepository.GetManyByUserIdAsync(user.Id, true))
.Where(c => c.OrganizationId == organizationId && c.UserId == null).ToList();
// Special case for unassigned ciphers
if (await CanAccessUnassignedCiphersAsync(organizationId))
{
var unassignedCiphers =
(await _cipherRepository.GetManyUnassignedOrganizationDetailsByOrganizationIdAsync(
organizationId));
// Users that can access unassigned ciphers can also delete them
deletableOrgCipherList.AddRange(unassignedCiphers.Select(c => new CipherDetails(c) { Manage = true }));
}
var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId);
var deletableOrgCiphers = deletableOrgCipherList
.Where(c => NormalCipherPermissions.CanDelete(user, c, organizationAbility))
.ToDictionary(c => c.Id);
return cipherIds.All(c => deletableOrgCiphers.ContainsKey(c));
}
/// <summary> /// <summary>
/// TODO: Move this to its own authorization handler or equivalent service - AC-2062 /// TODO: Move this to its own authorization handler or equivalent service - AC-2062
/// </summary> /// </summary>
@ -763,12 +835,12 @@ public class CiphersController : Controller
[HttpDelete("{id}/admin")] [HttpDelete("{id}/admin")]
[HttpPost("{id}/delete-admin")] [HttpPost("{id}/delete-admin")]
public async Task DeleteAdmin(string id) public async Task DeleteAdmin(Guid id)
{ {
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); var cipher = await GetByIdAsync(id, userId);
if (cipher == null || !cipher.OrganizationId.HasValue || if (cipher == null || !cipher.OrganizationId.HasValue ||
!await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) !await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
@ -808,7 +880,7 @@ public class CiphersController : Controller
var cipherIds = model.Ids.Select(i => new Guid(i)).ToList(); var cipherIds = model.Ids.Select(i => new Guid(i)).ToList();
if (string.IsNullOrWhiteSpace(model.OrganizationId) || if (string.IsNullOrWhiteSpace(model.OrganizationId) ||
!await CanEditCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds)) !await CanDeleteOrRestoreCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
@ -830,12 +902,12 @@ public class CiphersController : Controller
} }
[HttpPut("{id}/delete-admin")] [HttpPut("{id}/delete-admin")]
public async Task PutDeleteAdmin(string id) public async Task PutDeleteAdmin(Guid id)
{ {
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); var cipher = await GetByIdAsync(id, userId);
if (cipher == null || !cipher.OrganizationId.HasValue || if (cipher == null || !cipher.OrganizationId.HasValue ||
!await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) !await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
@ -871,7 +943,7 @@ public class CiphersController : Controller
var cipherIds = model.Ids.Select(i => new Guid(i)).ToList(); var cipherIds = model.Ids.Select(i => new Guid(i)).ToList();
if (string.IsNullOrWhiteSpace(model.OrganizationId) || if (string.IsNullOrWhiteSpace(model.OrganizationId) ||
!await CanEditCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds)) !await CanDeleteOrRestoreCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
@ -899,12 +971,12 @@ public class CiphersController : Controller
} }
[HttpPut("{id}/restore-admin")] [HttpPut("{id}/restore-admin")]
public async Task<CipherMiniResponseModel> PutRestoreAdmin(string id) public async Task<CipherMiniResponseModel> PutRestoreAdmin(Guid id)
{ {
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); var cipher = await GetByIdAsync(id, userId);
if (cipher == null || !cipher.OrganizationId.HasValue || if (cipher == null || !cipher.OrganizationId.HasValue ||
!await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) !await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
@ -944,7 +1016,7 @@ public class CiphersController : Controller
var cipherIdsToRestore = new HashSet<Guid>(model.Ids.Select(i => new Guid(i))); var cipherIdsToRestore = new HashSet<Guid>(model.Ids.Select(i => new Guid(i)));
if (model.OrganizationId == default || !await CanEditCipherAsAdminAsync(model.OrganizationId, cipherIdsToRestore)) if (model.OrganizationId == default || !await CanDeleteOrRestoreCipherAsAdminAsync(model.OrganizationId, cipherIdsToRestore))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }

View File

@ -1,8 +1,11 @@
using Bit.Core.AdminConsole.Repositories; using Bit.Core;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
@ -12,6 +15,7 @@ using Event = Stripe.Event;
namespace Bit.Billing.Services.Implementations; namespace Bit.Billing.Services.Implementations;
public class UpcomingInvoiceHandler( public class UpcomingInvoiceHandler(
IFeatureService featureService,
ILogger<StripeEventProcessor> logger, ILogger<StripeEventProcessor> logger,
IMailService mailService, IMailService mailService,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
@ -21,7 +25,8 @@ public class UpcomingInvoiceHandler(
IStripeEventService stripeEventService, IStripeEventService stripeEventService,
IStripeEventUtilityService stripeEventUtilityService, IStripeEventUtilityService stripeEventUtilityService,
IUserRepository userRepository, IUserRepository userRepository,
IValidateSponsorshipCommand validateSponsorshipCommand) IValidateSponsorshipCommand validateSponsorshipCommand,
IAutomaticTaxFactory automaticTaxFactory)
: IUpcomingInvoiceHandler : IUpcomingInvoiceHandler
{ {
public async Task HandleAsync(Event parsedEvent) public async Task HandleAsync(Event parsedEvent)
@ -136,6 +141,21 @@ public class UpcomingInvoiceHandler(
private async Task TryEnableAutomaticTaxAsync(Subscription subscription) private async Task TryEnableAutomaticTaxAsync(Subscription subscription)
{ {
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id));
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (updateOptions == null)
{
return;
}
await stripeFacade.UpdateSubscription(subscription.Id, updateOptions);
return;
}
if (subscription.AutomaticTax.Enabled || if (subscription.AutomaticTax.Enabled ||
!subscription.Customer.HasBillingLocation() || !subscription.Customer.HasBillingLocation() ||
await IsNonTaxableNonUSBusinessUseSubscription(subscription)) await IsNonTaxableNonUSBusinessUseSubscription(subscription))

View File

@ -34,6 +34,8 @@ public class ResetPasswordPolicyRequirementFactory : BasePolicyRequirementFactor
protected override IEnumerable<OrganizationUserType> ExemptRoles => []; protected override IEnumerable<OrganizationUserType> ExemptRoles => [];
protected override IEnumerable<OrganizationUserStatusType> ExemptStatuses => [OrganizationUserStatusType.Revoked];
public override ResetPasswordPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails) public override ResetPasswordPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails)
{ {
var result = policyDetails var result = policyDetails

View File

@ -568,6 +568,7 @@ public class OrganizationService : IOrganizationService
UseSecretsManager = license.UseSecretsManager, UseSecretsManager = license.UseSecretsManager,
SmSeats = license.SmSeats, SmSeats = license.SmSeats,
SmServiceAccounts = license.SmServiceAccounts, SmServiceAccounts = license.SmServiceAccounts,
UseRiskInsights = license.UseRiskInsights,
}; };
var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false); var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false);

View File

@ -1,4 +1,5 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using Bit.Core.Entities;
using Bit.Core.Utilities; using Bit.Core.Utilities;
namespace Bit.Core.Auth.Models.Api.Request; namespace Bit.Core.Auth.Models.Api.Request;
@ -7,6 +8,13 @@ public class OtherDeviceKeysUpdateRequestModel : DeviceKeysUpdateRequestModel
{ {
[Required] [Required]
public Guid DeviceId { get; set; } public Guid DeviceId { get; set; }
public Device ToDevice(Device existingDevice)
{
existingDevice.EncryptedPublicKey = EncryptedPublicKey;
existingDevice.EncryptedUserKey = EncryptedUserKey;
return existingDevice;
}
} }
public class DeviceKeysUpdateRequestModel public class DeviceKeysUpdateRequestModel

View File

@ -1,5 +1,4 @@
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.Utilities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models.Api; using Bit.Core.Models.Api;
@ -19,7 +18,7 @@ public class DeviceAuthRequestResponseModel : ResponseModel
Type = deviceAuthDetails.Type, Type = deviceAuthDetails.Type,
Identifier = deviceAuthDetails.Identifier, Identifier = deviceAuthDetails.Identifier,
CreationDate = deviceAuthDetails.CreationDate, CreationDate = deviceAuthDetails.CreationDate,
IsTrusted = deviceAuthDetails.IsTrusted() IsTrusted = deviceAuthDetails.IsTrusted,
}; };
if (deviceAuthDetails.AuthRequestId != null && deviceAuthDetails.AuthRequestCreatedAt != null) if (deviceAuthDetails.AuthRequestId != null && deviceAuthDetails.AuthRequestCreatedAt != null)

View File

@ -287,14 +287,14 @@ public class AuthRequestService : IAuthRequestService
private async Task NotifyAdminsOfDeviceApprovalRequestAsync(OrganizationUser organizationUser, User user) private async Task NotifyAdminsOfDeviceApprovalRequestAsync(OrganizationUser organizationUser, User user)
{ {
if (!_featureService.IsEnabled(FeatureFlagKeys.DeviceApprovalRequestAdminNotifications)) var adminEmails = await GetAdminAndAccountRecoveryEmailsAsync(organizationUser.OrganizationId);
if (adminEmails.Count == 0)
{ {
_logger.LogWarning("Skipped sending device approval notification to admins - feature flag disabled"); _logger.LogWarning("There are no admin emails to send to.");
return; return;
} }
var adminEmails = await GetAdminAndAccountRecoveryEmailsAsync(organizationUser.OrganizationId);
await _mailService.SendDeviceApprovalRequestedNotificationEmailAsync( await _mailService.SendDeviceApprovalRequestedNotificationEmailAsync(
adminEmails, adminEmails,
organizationUser.OrganizationId, organizationUser.OrganizationId,

View File

@ -47,6 +47,8 @@ public static class StripeConstants
public static class MetadataKeys public static class MetadataKeys
{ {
public const string OrganizationId = "organizationId"; public const string OrganizationId = "organizationId";
public const string ProviderId = "providerId";
public const string UserId = "userId";
} }
public static class PaymentBehavior public static class PaymentBehavior

View File

@ -21,7 +21,7 @@ public static class CustomerExtensions
/// <param name="customer"></param> /// <param name="customer"></param>
/// <returns></returns> /// <returns></returns>
public static bool HasTaxLocationVerified(this Customer customer) => public static bool HasTaxLocationVerified(this Customer customer) =>
customer?.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation;
public static decimal GetBillingBalance(this Customer customer) public static decimal GetBillingBalance(this Customer customer)
{ {

View File

@ -4,6 +4,7 @@ using Bit.Core.Billing.Licenses.Extensions;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
namespace Bit.Core.Billing.Extensions; namespace Bit.Core.Billing.Extensions;
@ -18,6 +19,9 @@ public static class ServiceCollectionExtensions
services.AddTransient<IPremiumUserBillingService, PremiumUserBillingService>(); services.AddTransient<IPremiumUserBillingService, PremiumUserBillingService>();
services.AddTransient<ISetupIntentCache, SetupIntentDistributedCache>(); services.AddTransient<ISetupIntentCache, SetupIntentDistributedCache>();
services.AddTransient<ISubscriberService, SubscriberService>(); services.AddTransient<ISubscriberService, SubscriberService>();
services.AddKeyedTransient<IAutomaticTaxStrategy, PersonalUseAutomaticTaxStrategy>(AutomaticTaxFactory.PersonalUse);
services.AddKeyedTransient<IAutomaticTaxStrategy, BusinessUseAutomaticTaxStrategy>(AutomaticTaxFactory.BusinessUse);
services.AddTransient<IAutomaticTaxFactory, AutomaticTaxFactory>();
services.AddLicenseServices(); services.AddLicenseServices();
services.AddPricingClient(); services.AddPricingClient();
} }

View File

@ -1,26 +0,0 @@
using Stripe;
namespace Bit.Core.Billing.Extensions;
public static class SubscriptionCreateOptionsExtensions
{
/// <summary>
/// Attempts to enable automatic tax for given new subscription options.
/// </summary>
/// <param name="options"></param>
/// <param name="customer">The existing customer.</param>
/// <returns>Returns true when successful, false when conditions are not met.</returns>
public static bool EnableAutomaticTax(this SubscriptionCreateOptions options, Customer customer)
{
// We might only need to check the automatic tax status.
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{
return false;
}
options.DefaultTaxRates = [];
options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
return true;
}
}

View File

@ -309,8 +309,7 @@ public class ProviderMigrator(
.SeatMinimum ?? 0; .SeatMinimum ?? 0;
var updateSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( var updateSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand(
provider.Id, provider,
provider.GatewaySubscriptionId,
[ [
(Plan: PlanType.EnterpriseMonthly, SeatsMinimum: enterpriseSeatMinimum), (Plan: PlanType.EnterpriseMonthly, SeatsMinimum: enterpriseSeatMinimum),
(Plan: PlanType.TeamsMonthly, SeatsMinimum: teamsSeatMinimum) (Plan: PlanType.TeamsMonthly, SeatsMinimum: teamsSeatMinimum)

View File

@ -75,6 +75,7 @@ public abstract record Plan
// Seats // Seats
public string StripePlanId { get; init; } public string StripePlanId { get; init; }
public string StripeSeatPlanId { get; init; } public string StripeSeatPlanId { get; init; }
[Obsolete("No longer used to retrieve a provider's price ID. Use ProviderPriceAdapter instead.")]
public string StripeProviderPortalSeatPlanId { get; init; } public string StripeProviderPortalSeatPlanId { get; init; }
public decimal BasePrice { get; init; } public decimal BasePrice { get; init; }
public decimal SeatPrice { get; init; } public decimal SeatPrice { get; init; }

View File

@ -0,0 +1,30 @@
#nullable enable
using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
namespace Bit.Core.Billing.Services.Contracts;
public class AutomaticTaxFactoryParameters
{
public AutomaticTaxFactoryParameters(PlanType planType)
{
PlanType = planType;
}
public AutomaticTaxFactoryParameters(ISubscriber subscriber, IEnumerable<string> prices)
{
Subscriber = subscriber;
Prices = prices;
}
public AutomaticTaxFactoryParameters(IEnumerable<string> prices)
{
Prices = prices;
}
public ISubscriber? Subscriber { get; init; }
public PlanType? PlanType { get; init; }
public IEnumerable<string>? Prices { get; init; }
}

View File

@ -1,8 +1,9 @@
using Bit.Core.Billing.Enums; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Enums;
namespace Bit.Core.Billing.Services.Contracts; namespace Bit.Core.Billing.Services.Contracts;
public record ChangeProviderPlanCommand( public record ChangeProviderPlanCommand(
Provider Provider,
Guid ProviderPlanId, Guid ProviderPlanId,
PlanType NewPlan, PlanType NewPlan);
string GatewaySubscriptionId);

View File

@ -1,10 +1,10 @@
using Bit.Core.Billing.Enums; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Enums;
namespace Bit.Core.Billing.Services.Contracts; namespace Bit.Core.Billing.Services.Contracts;
/// <param name="Id">The ID of the provider to update the seat minimums for.</param> /// <param name="Provider">The provider to update the seat minimums for.</param>
/// <param name="Configuration">The new seat minimums for the provider.</param> /// <param name="Configuration">The new seat minimums for the provider.</param>
public record UpdateProviderSeatMinimumsCommand( public record UpdateProviderSeatMinimumsCommand(
Guid Id, Provider Provider,
string GatewaySubscriptionId,
IReadOnlyCollection<(PlanType Plan, int SeatsMinimum)> Configuration); IReadOnlyCollection<(PlanType Plan, int SeatsMinimum)> Configuration);

View File

@ -0,0 +1,11 @@
using Bit.Core.Billing.Services.Contracts;
namespace Bit.Core.Billing.Services;
/// <summary>
/// Responsible for defining the correct automatic tax strategy for either personal use of business use.
/// </summary>
public interface IAutomaticTaxFactory
{
Task<IAutomaticTaxStrategy> CreateAsync(AutomaticTaxFactoryParameters parameters);
}

View File

@ -0,0 +1,33 @@
#nullable enable
using Stripe;
namespace Bit.Core.Billing.Services;
public interface IAutomaticTaxStrategy
{
/// <summary>
///
/// </summary>
/// <param name="subscription"></param>
/// <returns>
/// Returns <see cref="SubscriptionUpdateOptions" /> if changes are to be applied to the subscription, returns null
/// otherwise.
/// </returns>
SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription);
/// <summary>
/// Modifies an existing <see cref="SubscriptionCreateOptions" /> object with the automatic tax flag set correctly.
/// </summary>
/// <param name="options"></param>
/// <param name="customer"></param>
void SetCreateOptions(SubscriptionCreateOptions options, Customer customer);
/// <summary>
/// Modifies an existing <see cref="SubscriptionUpdateOptions" /> object with the automatic tax flag set correctly.
/// </summary>
/// <param name="options"></param>
/// <param name="subscription"></param>
void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription);
void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options);
}

View File

@ -0,0 +1,50 @@
#nullable enable
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Entities;
using Bit.Core.Services;
namespace Bit.Core.Billing.Services.Implementations.AutomaticTax;
public class AutomaticTaxFactory(
IFeatureService featureService,
IPricingClient pricingClient) : IAutomaticTaxFactory
{
public const string BusinessUse = "business-use";
public const string PersonalUse = "personal-use";
private readonly Lazy<Task<IEnumerable<string>>> _personalUsePlansTask = new(async () =>
{
var plans = await Task.WhenAll(
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019),
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually));
return plans.Select(plan => plan.PasswordManager.StripePlanId);
});
public async Task<IAutomaticTaxStrategy> CreateAsync(AutomaticTaxFactoryParameters parameters)
{
if (parameters.Subscriber is User)
{
return new PersonalUseAutomaticTaxStrategy(featureService);
}
if (parameters.PlanType.HasValue)
{
var plan = await pricingClient.GetPlanOrThrow(parameters.PlanType.Value);
return plan.CanBeUsedByBusiness
? new BusinessUseAutomaticTaxStrategy(featureService)
: new PersonalUseAutomaticTaxStrategy(featureService);
}
var personalUsePlans = await _personalUsePlansTask.Value;
if (parameters.Prices != null && parameters.Prices.Any(x => personalUsePlans.Any(y => y == x)))
{
return new PersonalUseAutomaticTaxStrategy(featureService);
}
return new BusinessUseAutomaticTaxStrategy(featureService);
}
}

View File

@ -0,0 +1,96 @@
#nullable enable
using Bit.Core.Billing.Extensions;
using Bit.Core.Services;
using Stripe;
namespace Bit.Core.Billing.Services.Implementations.AutomaticTax;
public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy
{
public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription)
{
if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
{
return null;
}
var shouldBeEnabled = ShouldBeEnabled(subscription.Customer);
if (subscription.AutomaticTax.Enabled == shouldBeEnabled)
{
return null;
}
var options = new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = shouldBeEnabled
},
DefaultTaxRates = []
};
return options;
}
public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer)
{
options.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = ShouldBeEnabled(customer)
};
}
public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription)
{
if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
{
return;
}
var shouldBeEnabled = ShouldBeEnabled(subscription.Customer);
if (subscription.AutomaticTax.Enabled == shouldBeEnabled)
{
return;
}
options.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = shouldBeEnabled
};
options.DefaultTaxRates = [];
}
public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options)
{
options.AutomaticTax ??= new InvoiceAutomaticTaxOptions();
if (options.CustomerDetails.Address.Country == "US")
{
options.AutomaticTax.Enabled = true;
return;
}
options.AutomaticTax.Enabled = options.CustomerDetails.TaxIds != null && options.CustomerDetails.TaxIds.Any();
}
private bool ShouldBeEnabled(Customer customer)
{
if (!customer.HasTaxLocationVerified())
{
return false;
}
if (customer.Address.Country == "US")
{
return true;
}
if (customer.TaxIds == null)
{
throw new ArgumentNullException(nameof(customer.TaxIds), "`customer.tax_ids` must be expanded.");
}
return customer.TaxIds.Any();
}
}

View File

@ -0,0 +1,64 @@
#nullable enable
using Bit.Core.Billing.Extensions;
using Bit.Core.Services;
using Stripe;
namespace Bit.Core.Billing.Services.Implementations.AutomaticTax;
public class PersonalUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy
{
public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer)
{
options.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = ShouldBeEnabled(customer)
};
}
public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription)
{
if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
{
return;
}
options.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = ShouldBeEnabled(subscription.Customer)
};
options.DefaultTaxRates = [];
}
public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription)
{
if (!featureService.IsEnabled(FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
{
return null;
}
if (subscription.AutomaticTax.Enabled == ShouldBeEnabled(subscription.Customer))
{
return null;
}
var options = new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = ShouldBeEnabled(subscription.Customer),
},
DefaultTaxRates = []
};
return options;
}
public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options)
{
options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true };
}
private static bool ShouldBeEnabled(Customer customer)
{
return customer.HasTaxLocationVerified();
}
}

View File

@ -1,9 +1,11 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories; using Bit.Core.Repositories;
@ -23,6 +25,7 @@ namespace Bit.Core.Billing.Services.Implementations;
public class OrganizationBillingService( public class OrganizationBillingService(
IBraintreeGateway braintreeGateway, IBraintreeGateway braintreeGateway,
IFeatureService featureService,
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
ILogger<OrganizationBillingService> logger, ILogger<OrganizationBillingService> logger,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
@ -30,7 +33,8 @@ public class OrganizationBillingService(
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService,
ITaxService taxService) : IOrganizationBillingService ITaxService taxService,
IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService
{ {
public async Task Finalize(OrganizationSale sale) public async Task Finalize(OrganizationSale sale)
{ {
@ -143,7 +147,7 @@ public class OrganizationBillingService(
Coupon = customerSetup.Coupon, Coupon = customerSetup.Coupon,
Description = organization.DisplayBusinessName(), Description = organization.DisplayBusinessName(),
Email = organization.BillingEmail, Email = organization.BillingEmail,
Expand = ["tax"], Expand = ["tax", "tax_ids"],
InvoiceSettings = new CustomerInvoiceSettingsOptions InvoiceSettings = new CustomerInvoiceSettingsOptions
{ {
CustomFields = [ CustomFields = [
@ -369,21 +373,8 @@ public class OrganizationBillingService(
} }
} }
var customerHasTaxInfo = customer is
{
Address:
{
Country: not null and not "",
PostalCode: not null and not ""
}
};
var subscriptionCreateOptions = new SubscriptionCreateOptions var subscriptionCreateOptions = new SubscriptionCreateOptions
{ {
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customerHasTaxInfo
},
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id, Customer = customer.Id,
Items = subscriptionItemOptionsList, Items = subscriptionItemOptionsList,
@ -395,6 +386,18 @@ public class OrganizationBillingService(
TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays
}; };
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType);
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions();
subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation();
}
return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
} }

View File

@ -2,6 +2,7 @@
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
@ -9,6 +10,7 @@ using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Braintree; using Braintree;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Stripe; using Stripe;
using Customer = Stripe.Customer; using Customer = Stripe.Customer;
@ -20,19 +22,21 @@ using static Utilities;
public class PremiumUserBillingService( public class PremiumUserBillingService(
IBraintreeGateway braintreeGateway, IBraintreeGateway braintreeGateway,
IFeatureService featureService,
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
ILogger<PremiumUserBillingService> logger, ILogger<PremiumUserBillingService> logger,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService,
IUserRepository userRepository) : IPremiumUserBillingService IUserRepository userRepository,
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService
{ {
public async Task Credit(User user, decimal amount) public async Task Credit(User user, decimal amount)
{ {
var customer = await subscriberService.GetCustomer(user); var customer = await subscriberService.GetCustomer(user);
// Negative credit represents a balance and all Stripe denomination is in cents. // Negative credit represents a balance and all Stripe denomination is in cents.
var credit = (long)amount * -100; var credit = (long)(amount * -100);
if (customer == null) if (customer == null)
{ {
@ -318,10 +322,6 @@ public class PremiumUserBillingService(
var subscriptionCreateOptions = new SubscriptionCreateOptions var subscriptionCreateOptions = new SubscriptionCreateOptions
{ {
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported,
},
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id, Customer = customer.Id,
Items = subscriptionItemOptionsList, Items = subscriptionItemOptionsList,
@ -335,6 +335,18 @@ public class PremiumUserBillingService(
OffSession = true OffSession = true
}; };
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported,
};
}
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
if (usingPayPal) if (usingPayPal)

View File

@ -1,6 +1,7 @@
using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
@ -20,11 +21,13 @@ namespace Bit.Core.Billing.Services.Implementations;
public class SubscriberService( public class SubscriberService(
IBraintreeGateway braintreeGateway, IBraintreeGateway braintreeGateway,
IFeatureService featureService,
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
ILogger<SubscriberService> logger, ILogger<SubscriberService> logger,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ITaxService taxService) : ISubscriberService ITaxService taxService,
IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService
{ {
public async Task CancelSubscription( public async Task CancelSubscription(
ISubscriber subscriber, ISubscriber subscriber,
@ -438,7 +441,8 @@ public class SubscriberService(
ArgumentNullException.ThrowIfNull(subscriber); ArgumentNullException.ThrowIfNull(subscriber);
ArgumentNullException.ThrowIfNull(tokenizedPaymentSource); ArgumentNullException.ThrowIfNull(tokenizedPaymentSource);
var customer = await GetCustomerOrThrow(subscriber); var customerGetOptions = new CustomerGetOptions { Expand = ["tax", "tax_ids"] };
var customer = await GetCustomerOrThrow(subscriber, customerGetOptions);
var (type, token) = tokenizedPaymentSource; var (type, token) = tokenizedPaymentSource;
@ -597,7 +601,7 @@ public class SubscriberService(
Expand = ["subscriptions", "tax", "tax_ids"] Expand = ["subscriptions", "tax", "tax_ids"]
}); });
await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions customer = await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{ {
Address = new AddressOptions Address = new AddressOptions
{ {
@ -607,7 +611,8 @@ public class SubscriberService(
Line2 = taxInformation.Line2, Line2 = taxInformation.Line2,
City = taxInformation.City, City = taxInformation.City,
State = taxInformation.State State = taxInformation.State
} },
Expand = ["subscriptions", "tax", "tax_ids"]
}); });
var taxId = customer.TaxIds?.FirstOrDefault(); var taxId = customer.TaxIds?.FirstOrDefault();
@ -661,21 +666,42 @@ public class SubscriberService(
} }
} }
if (SubscriberIsEligibleForAutomaticTax(subscriber, customer)) if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{ {
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
new SubscriptionUpdateOptions {
var subscriptionGetOptions = new SubscriptionGetOptions
{ {
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } Expand = ["customer.tax", "customer.tax_ids"]
}); };
var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id));
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (automaticTaxOptions?.AutomaticTax?.Enabled != null)
{
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions);
}
}
} }
else
{
if (SubscriberIsEligibleForAutomaticTax(subscriber, customer))
{
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
return; return;
bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer)
=> !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && => !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) &&
(localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && (localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) &&
localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported;
}
} }
public async Task VerifyBankAccount( public async Task VerifyBankAccount(

View File

@ -113,6 +113,7 @@ public static class FeatureFlagKeys
/* Auth Team */ /* Auth Team */
public const string PM9112DeviceApprovalPersistence = "pm-9112-device-approval-persistence"; public const string PM9112DeviceApprovalPersistence = "pm-9112-device-approval-persistence";
public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence";
public const string DuoRedirect = "duo-redirect"; public const string DuoRedirect = "duo-redirect";
public const string EmailVerification = "email-verification"; public const string EmailVerification = "email-verification";
public const string EmailVerificationDisableTimingDelays = "email-verification-disable-timing-delays"; public const string EmailVerificationDisableTimingDelays = "email-verification-disable-timing-delays";
@ -148,6 +149,8 @@ public static class FeatureFlagKeys
public const string P15179_AddExistingOrgsFromProviderPortal = "pm-15179-add-existing-orgs-from-provider-portal"; public const string P15179_AddExistingOrgsFromProviderPortal = "pm-15179-add-existing-orgs-from-provider-portal";
public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features";
public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method"; public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method";
public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements";
public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates";
/* Key Management Team */ /* Key Management Team */
public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair";
@ -169,6 +172,7 @@ public static class FeatureFlagKeys
public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication"; public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication";
public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync"; public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync";
public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias"; public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias";
public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias"; public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias";
/* Platform Team */ /* Platform Team */

View File

@ -23,8 +23,8 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="AspNetCoreRateLimit.Redis" Version="2.0.0" /> <PackageReference Include="AspNetCoreRateLimit.Redis" Version="2.0.0" />
<PackageReference Include="AWSSDK.SimpleEmail" Version="3.7.402.28" /> <PackageReference Include="AWSSDK.SimpleEmail" Version="3.7.402.61" />
<PackageReference Include="AWSSDK.SQS" Version="3.7.400.85" /> <PackageReference Include="AWSSDK.SQS" Version="3.7.400.118" />
<PackageReference Include="Azure.Data.Tables" Version="12.9.0" /> <PackageReference Include="Azure.Data.Tables" Version="12.9.0" />
<PackageReference Include="Azure.Extensions.AspNetCore.DataProtection.Blobs" Version="1.3.4" /> <PackageReference Include="Azure.Extensions.AspNetCore.DataProtection.Blobs" Version="1.3.4" />
<PackageReference Include="Microsoft.AspNetCore.DataProtection" Version="8.0.10" /> <PackageReference Include="Microsoft.AspNetCore.DataProtection" Version="8.0.10" />
@ -61,7 +61,7 @@
<PackageReference Include="Otp.NET" Version="1.4.0" /> <PackageReference Include="Otp.NET" Version="1.4.0" />
<PackageReference Include="YubicoDotNetClient" Version="1.2.0" /> <PackageReference Include="YubicoDotNetClient" Version="1.2.0" />
<PackageReference Include="Microsoft.Extensions.Caching.StackExchangeRedis" Version="8.0.10" /> <PackageReference Include="Microsoft.Extensions.Caching.StackExchangeRedis" Version="8.0.10" />
<PackageReference Include="LaunchDarkly.ServerSdk" Version="8.6.0" /> <PackageReference Include="LaunchDarkly.ServerSdk" Version="8.7.0" />
<PackageReference Include="Quartz" Version="3.13.1" /> <PackageReference Include="Quartz" Version="3.13.1" />
<PackageReference Include="Quartz.Extensions.Hosting" Version="3.13.1" /> <PackageReference Include="Quartz.Extensions.Hosting" Version="3.13.1" />
<PackageReference Include="Quartz.Extensions.DependencyInjection" Version="3.13.1" /> <PackageReference Include="Quartz.Extensions.DependencyInjection" Version="3.13.1" />

View File

@ -14,5 +14,7 @@ public enum ClientType : byte
[Display(Name = "Desktop App")] [Display(Name = "Desktop App")]
Desktop = 3, Desktop = 3,
[Display(Name = "Mobile App")] [Display(Name = "Mobile App")]
Mobile = 4 Mobile = 4,
[Display(Name = "CLI")]
Cli = 5
} }

View File

@ -20,6 +20,7 @@ public class RotateUserAccountKeysData
public IEnumerable<EmergencyAccess> EmergencyAccesses { get; set; } public IEnumerable<EmergencyAccess> EmergencyAccesses { get; set; }
public IReadOnlyList<OrganizationUser> OrganizationUsers { get; set; } public IReadOnlyList<OrganizationUser> OrganizationUsers { get; set; }
public IEnumerable<WebAuthnLoginRotateKeyData> WebAuthnKeys { get; set; } public IEnumerable<WebAuthnLoginRotateKeyData> WebAuthnKeys { get; set; }
public IEnumerable<Device> DeviceKeys { get; set; }
// User vault data encrypted by the userkey // User vault data encrypted by the userkey
public IEnumerable<Cipher> Ciphers { get; set; } public IEnumerable<Cipher> Ciphers { get; set; }

View File

@ -20,6 +20,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
private readonly ISendRepository _sendRepository; private readonly ISendRepository _sendRepository;
private readonly IEmergencyAccessRepository _emergencyAccessRepository; private readonly IEmergencyAccessRepository _emergencyAccessRepository;
private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IDeviceRepository _deviceRepository;
private readonly IPushNotificationService _pushService; private readonly IPushNotificationService _pushService;
private readonly IdentityErrorDescriber _identityErrorDescriber; private readonly IdentityErrorDescriber _identityErrorDescriber;
private readonly IWebAuthnCredentialRepository _credentialRepository; private readonly IWebAuthnCredentialRepository _credentialRepository;
@ -42,6 +43,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
public RotateUserAccountKeysCommand(IUserService userService, IUserRepository userRepository, public RotateUserAccountKeysCommand(IUserService userService, IUserRepository userRepository,
ICipherRepository cipherRepository, IFolderRepository folderRepository, ISendRepository sendRepository, ICipherRepository cipherRepository, IFolderRepository folderRepository, ISendRepository sendRepository,
IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository, IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository,
IDeviceRepository deviceRepository,
IPasswordHasher<User> passwordHasher, IPasswordHasher<User> passwordHasher,
IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository) IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository)
{ {
@ -52,6 +54,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
_sendRepository = sendRepository; _sendRepository = sendRepository;
_emergencyAccessRepository = emergencyAccessRepository; _emergencyAccessRepository = emergencyAccessRepository;
_organizationUserRepository = organizationUserRepository; _organizationUserRepository = organizationUserRepository;
_deviceRepository = deviceRepository;
_pushService = pushService; _pushService = pushService;
_identityErrorDescriber = errors; _identityErrorDescriber = errors;
_credentialRepository = credentialRepository; _credentialRepository = credentialRepository;
@ -127,6 +130,11 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand
saveEncryptedDataActions.Add(_credentialRepository.UpdateKeysForRotationAsync(user.Id, model.WebAuthnKeys)); saveEncryptedDataActions.Add(_credentialRepository.UpdateKeysForRotationAsync(user.Id, model.WebAuthnKeys));
} }
if (model.DeviceKeys.Any())
{
saveEncryptedDataActions.Add(_deviceRepository.UpdateKeysForRotationAsync(user.Id, model.DeviceKeys));
}
await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions); await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions);
await _pushService.PushLogOutAsync(user.Id); await _pushService.PushLogOutAsync(user.Id);
return IdentityResult.Success; return IdentityResult.Success;

View File

@ -6,11 +6,8 @@
<table border="0" cellpadding="0" cellspacing="0" width="100%" <table border="0" cellpadding="0" cellspacing="0" width="100%"
style="padding-left:30px; padding-right: 5px; padding-top: 20px;"> style="padding-left:30px; padding-right: 5px; padding-top: 20px;">
<tr> <tr>
<td <td style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 24px; color: #ffffff; line-height: 32px; font-weight: 500; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 24px; color: #ffffff; line-height: 32px; font-weight: 500; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> {{OrgName}} has identified {{TaskCount}} critical {{plurality TaskCount "login" "logins"}} that {{plurality TaskCount "requires" "require"}} a password change
{{OrgName}} has identified {{TaskCount}} critical login{{#if TaskCountPlural}}s{{/if}} that require{{#unless
TaskCountPlural}}s{{/unless}} a
password change
</td> </td>
</tr> </tr>
</table> </table>

View File

@ -1,7 +1,5 @@
{{#>FullTextLayout}} {{#>FullTextLayout}}
{{OrgName}} has identified {{TaskCount}} critical login{{#if TaskCountPlural}}s{{/if}} that require{{#unless {{OrgName}} has identified {{TaskCount}} critical {{plurality TaskCount "login" "logins"}} that {{plurality TaskCount "requires" "require"}} a password change
TaskCountPlural}}s{{/unless}} a
password change
{{>@partial-block}} {{>@partial-block}}

View File

@ -14,18 +14,17 @@
</td> </td>
</tr> </tr>
</table> </table>
<table width="100%" border="0" cellpadding="0" cellspacing="0" <table width="100%" border="0" cellpadding="0" cellspacing="0" style="padding-bottom: 24px; padding-left: 24px; padding-right: 24px; text-align: center;" align="center">
style="display: table; width:100%; padding-bottom: 24px; text-align: center;" align="center">
<tr> <tr>
<td display="display: table-cell"> <td>
<a href="{{ReviewPasswordsUrl}}" clicktracking=off target="_blank" <a href="{{ReviewPasswordsUrl}}" clicktracking=off target="_blank"
style="display: inline-block; font-weight: bold; color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; border-radius: 999px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> style="display: inline-block; font-weight: bold; color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; border-radius: 999px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
Review at-risk passwords Review at-risk passwords
</a> </a>
</td> </td>
</tr> </tr>
<table width="100%" border="0" cellpadding="0" cellspacing="0" </table>
style="display: table; width:100%; padding-bottom: 24px; text-align: center;" align="center"> <table width="100%" border="0" cellpadding="0" cellspacing="0" style="padding-bottom: 24px; padding-left: 24px; padding-right: 24px; text-align: center;" align="center">
<tr> <tr>
<td display="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; font-style: normal; font-weight: 400; font-size: 12px; line-height: 16px;"> <td display="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; font-style: normal; font-weight: 400; font-size: 12px; line-height: 16px;">
{{formatAdminOwnerEmails AdminOwnerEmails}} {{formatAdminOwnerEmails AdminOwnerEmails}}

View File

@ -55,6 +55,7 @@ public class OrganizationLicense : ILicense
UseSecretsManager = org.UseSecretsManager; UseSecretsManager = org.UseSecretsManager;
SmSeats = org.SmSeats; SmSeats = org.SmSeats;
SmServiceAccounts = org.SmServiceAccounts; SmServiceAccounts = org.SmServiceAccounts;
UseRiskInsights = org.UseRiskInsights;
// Deprecated. Left for backwards compatibility with old license versions. // Deprecated. Left for backwards compatibility with old license versions.
LimitCollectionCreationDeletion = org.LimitCollectionCreation || org.LimitCollectionDeletion; LimitCollectionCreationDeletion = org.LimitCollectionCreation || org.LimitCollectionDeletion;
@ -143,6 +144,7 @@ public class OrganizationLicense : ILicense
public bool UseSecretsManager { get; set; } public bool UseSecretsManager { get; set; }
public int? SmSeats { get; set; } public int? SmSeats { get; set; }
public int? SmServiceAccounts { get; set; } public int? SmServiceAccounts { get; set; }
public bool UseRiskInsights { get; set; }
// Deprecated. Left for backwards compatibility with old license versions. // Deprecated. Left for backwards compatibility with old license versions.
public bool LimitCollectionCreationDeletion { get; set; } = true; public bool LimitCollectionCreationDeletion { get; set; } = true;
@ -218,7 +220,8 @@ public class OrganizationLicense : ILicense
!p.Name.Equals(nameof(Issued)) && !p.Name.Equals(nameof(Issued)) &&
!p.Name.Equals(nameof(Refresh)) !p.Name.Equals(nameof(Refresh))
) )
)) ) &&
!p.Name.Equals(nameof(UseRiskInsights)))
.OrderBy(p => p.Name) .OrderBy(p => p.Name)
.Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}")
.Aggregate((c, n) => $"{c}|{n}"); .Aggregate((c, n) => $"{c}|{n}");

View File

@ -1,62 +0,0 @@
using Bit.Core.Billing;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Stripe;
using Plan = Bit.Core.Models.StaticStore.Plan;
namespace Bit.Core.Models.Business;
public class ProviderSubscriptionUpdate : SubscriptionUpdate
{
private readonly string _planId;
private readonly int _previouslyPurchasedSeats;
private readonly int _newlyPurchasedSeats;
protected override List<string> PlanIds => [_planId];
public ProviderSubscriptionUpdate(
Plan plan,
int previouslyPurchasedSeats,
int newlyPurchasedSeats)
{
if (!plan.Type.SupportsConsolidatedBilling())
{
throw new BillingException(
message: $"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing");
}
_planId = plan.PasswordManager.StripeProviderPortalSeatPlanId;
_previouslyPurchasedSeats = previouslyPurchasedSeats;
_newlyPurchasedSeats = newlyPurchasedSeats;
}
public override List<SubscriptionItemOptions> RevertItemsOptions(Subscription subscription)
{
var subscriptionItem = FindSubscriptionItem(subscription, _planId);
return
[
new SubscriptionItemOptions
{
Id = subscriptionItem.Id,
Price = _planId,
Quantity = _previouslyPurchasedSeats
}
];
}
public override List<SubscriptionItemOptions> UpgradeItemsOptions(Subscription subscription)
{
var subscriptionItem = FindSubscriptionItem(subscription, _planId);
return
[
new SubscriptionItemOptions
{
Id = subscriptionItem.Id,
Price = _planId,
Quantity = _newlyPurchasedSeats
}
];
}
}

View File

@ -6,9 +6,7 @@ public class SecurityTaskNotificationViewModel : BaseMailModel
public int TaskCount { get; set; } public int TaskCount { get; set; }
public bool TaskCountPlural => TaskCount != 1; public List<string> AdminOwnerEmails { get; set; }
public IEnumerable<string> AdminOwnerEmails { get; set; }
public string ReviewPasswordsUrl => $"{WebVaultUrl}/browser-extension-prompt"; public string ReviewPasswordsUrl => $"{WebVaultUrl}/browser-extension-prompt";
} }

View File

@ -19,6 +19,7 @@ public class NotificationStatusDetails
public string? Body { get; set; } public string? Body { get; set; }
public DateTime CreationDate { get; set; } public DateTime CreationDate { get; set; }
public DateTime RevisionDate { get; set; } public DateTime RevisionDate { get; set; }
public Guid? TaskId { get; set; }
// Notification Status fields // Notification Status fields
public DateTime? ReadDate { get; set; } public DateTime? ReadDate { get; set; }
public DateTime? DeletedDate { get; set; } public DateTime? DeletedDate { get; set; }

View File

@ -1,5 +1,6 @@
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.KeyManagement.UserKey;
#nullable enable #nullable enable
@ -16,4 +17,5 @@ public interface IDeviceRepository : IRepository<Device, Guid>
// other requests. // other requests.
Task<ICollection<DeviceAuthDetails>> GetManyByUserIdWithDeviceAuth(Guid userId); Task<ICollection<DeviceAuthDetails>> GetManyByUserIdWithDeviceAuth(Guid userId);
Task ClearPushTokenAsync(Guid id); Task ClearPushTokenAsync(Guid id);
UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable<Device> devices);
} }

View File

@ -1,5 +1,4 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Api.Requests.Accounts; using Bit.Core.Billing.Models.Api.Requests.Accounts;
@ -26,11 +25,6 @@ public interface IPaymentService
int? newlyPurchasedAdditionalSecretsManagerServiceAccounts, int? newlyPurchasedAdditionalSecretsManagerServiceAccounts,
int newlyPurchasedAdditionalStorage); int newlyPurchasedAdditionalStorage);
Task<string> AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats); Task<string> AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats);
Task<string> AdjustSeats(
Provider provider,
Plan plan,
int currentlySubscribedSeats,
int newlySubscribedSeats);
Task<string> AdjustSmSeatsAsync(Organization organization, Plan plan, int additionalSeats); Task<string> AdjustSmSeatsAsync(Organization organization, Plan plan, int additionalSeats);
Task<string> AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId); Task<string> AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId);

View File

@ -1,5 +1,6 @@
using System.Net; using System.Net;
using System.Reflection; using System.Reflection;
using System.Text.Json;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Mail; using Bit.Core.AdminConsole.Models.Mail;
@ -752,7 +753,21 @@ public class HandlebarsMailService : IMailService
return; return;
} }
var emailList = ((IEnumerable<string>)parameters[0]).ToList(); var emailList = new List<string>();
if (parameters[0] is JsonElement jsonElement && jsonElement.ValueKind == JsonValueKind.Array)
{
emailList = jsonElement.EnumerateArray().Select(e => e.GetString()).ToList();
}
else if (parameters[0] is IEnumerable<string> emails)
{
emailList = emails.ToList();
}
else
{
writer.WriteSafeString(string.Empty);
return;
}
if (emailList.Count == 0) if (emailList.Count == 0)
{ {
writer.WriteSafeString(string.Empty); writer.WriteSafeString(string.Empty);
@ -774,11 +789,34 @@ public class HandlebarsMailService : IMailService
{ {
outputMessage += string.Join(", ", emailList.Take(emailList.Count - 1) outputMessage += string.Join(", ", emailList.Take(emailList.Count - 1)
.Select(email => constructAnchorElement(email))); .Select(email => constructAnchorElement(email)));
outputMessage += $", and {constructAnchorElement(emailList.Last())}."; outputMessage += $" and {constructAnchorElement(emailList.Last())}.";
} }
writer.WriteSafeString($"{outputMessage}"); writer.WriteSafeString($"{outputMessage}");
}); });
// Returns the singular or plural form of a word based on the provided numeric value.
Handlebars.RegisterHelper("plurality", (writer, context, parameters) =>
{
if (parameters.Length != 3)
{
writer.WriteSafeString(string.Empty);
return;
}
var numeric = parameters[0];
var singularText = parameters[1].ToString();
var pluralText = parameters[2].ToString();
if (numeric is int number)
{
writer.WriteSafeString(number == 1 ? singularText : pluralText);
}
else
{
writer.WriteSafeString(string.Empty);
}
});
} }
public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token)
@ -1250,7 +1288,7 @@ public class HandlebarsMailService : IMailService
{ {
OrgName = CoreHelpers.SanitizeForEmail(sanitizedOrgName, false), OrgName = CoreHelpers.SanitizeForEmail(sanitizedOrgName, false),
TaskCount = notification.TaskCount, TaskCount = notification.TaskCount,
AdminOwnerEmails = adminOwnerEmails, AdminOwnerEmails = adminOwnerEmails.ToList(),
WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash,
}; };
message.Category = "SecurityTasksNotification"; message.Category = "SecurityTasksNotification";

View File

@ -1,5 +1,4 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
@ -10,6 +9,8 @@ using Bit.Core.Billing.Models.Api.Responses;
using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
@ -17,6 +18,7 @@ using Bit.Core.Models.BitStripe;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Stripe; using Stripe;
using PaymentMethod = Stripe.PaymentMethod; using PaymentMethod = Stripe.PaymentMethod;
@ -37,6 +39,8 @@ public class StripePaymentService : IPaymentService
private readonly ITaxService _taxService; private readonly ITaxService _taxService;
private readonly ISubscriberService _subscriberService; private readonly ISubscriberService _subscriberService;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly IAutomaticTaxFactory _automaticTaxFactory;
private readonly IAutomaticTaxStrategy _personalUseTaxStrategy;
public StripePaymentService( public StripePaymentService(
ITransactionRepository transactionRepository, ITransactionRepository transactionRepository,
@ -47,7 +51,9 @@ public class StripePaymentService : IPaymentService
IFeatureService featureService, IFeatureService featureService,
ITaxService taxService, ITaxService taxService,
ISubscriberService subscriberService, ISubscriberService subscriberService,
IPricingClient pricingClient) IPricingClient pricingClient,
IAutomaticTaxFactory automaticTaxFactory,
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy)
{ {
_transactionRepository = transactionRepository; _transactionRepository = transactionRepository;
_logger = logger; _logger = logger;
@ -58,6 +64,8 @@ public class StripePaymentService : IPaymentService
_taxService = taxService; _taxService = taxService;
_subscriberService = subscriberService; _subscriberService = subscriberService;
_pricingClient = pricingClient; _pricingClient = pricingClient;
_automaticTaxFactory = automaticTaxFactory;
_personalUseTaxStrategy = personalUseTaxStrategy;
} }
private async Task ChangeOrganizationSponsorship( private async Task ChangeOrganizationSponsorship(
@ -92,9 +100,7 @@ public class StripePaymentService : IPaymentService
SubscriptionUpdate subscriptionUpdate, bool invoiceNow = false) SubscriptionUpdate subscriptionUpdate, bool invoiceNow = false)
{ {
// remember, when in doubt, throw // remember, when in doubt, throw
var subGetOptions = new SubscriptionGetOptions(); var subGetOptions = new SubscriptionGetOptions { Expand = ["customer.tax", "customer.tax_ids"] };
// subGetOptions.AddExpand("customer");
subGetOptions.AddExpand("customer.tax");
var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions); var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions);
if (sub == null) if (sub == null)
{ {
@ -125,7 +131,19 @@ public class StripePaymentService : IPaymentService
new SubscriptionPendingInvoiceItemIntervalOptions { Interval = "month" }; new SubscriptionPendingInvoiceItemIntervalOptions { Interval = "month" };
} }
subUpdateOptions.EnableAutomaticTax(sub.Customer, sub); if (subscriptionUpdate is CompleteSubscriptionUpdate)
{
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, updatedItemOptions.Select(x => x.Plan ?? x.Price));
var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters);
automaticTaxStrategy.SetUpdateOptions(subUpdateOptions, sub);
}
else
{
subUpdateOptions.EnableAutomaticTax(sub.Customer, sub);
}
}
if (!subscriptionUpdate.UpdateNeeded(sub)) if (!subscriptionUpdate.UpdateNeeded(sub))
{ {
@ -233,18 +251,6 @@ public class StripePaymentService : IPaymentService
public Task<string> AdjustSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats) => public Task<string> AdjustSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats) =>
FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats)); FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats));
public Task<string> AdjustSeats(
Provider provider,
StaticStore.Plan plan,
int currentlySubscribedSeats,
int newlySubscribedSeats)
=> FinalizeSubscriptionChangeAsync(
provider,
new ProviderSubscriptionUpdate(
plan,
currentlySubscribedSeats,
newlySubscribedSeats));
public Task<string> AdjustSmSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats) => public Task<string> AdjustSmSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats) =>
FinalizeSubscriptionChangeAsync( FinalizeSubscriptionChangeAsync(
organization, organization,
@ -812,21 +818,46 @@ public class StripePaymentService : IPaymentService
}); });
} }
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) && if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
customer.Subscriptions.Any(sub =>
sub.Id == subscriber.GatewaySubscriptionId &&
!sub.AutomaticTax.Enabled) &&
customer.HasTaxLocationVerified())
{ {
var subscriptionUpdateOptions = new SubscriptionUpdateOptions if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
{ {
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, var subscriptionGetOptions = new SubscriptionGetOptions
DefaultTaxRates = [] {
}; Expand = ["customer.tax", "customer.tax_ids"]
};
var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
_ = await _stripeAdapter.SubscriptionUpdateAsync( var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id));
subscriber.GatewaySubscriptionId, var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters);
subscriptionUpdateOptions); var subscriptionUpdateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (subscriptionUpdateOptions != null)
{
_ = await _stripeAdapter.SubscriptionUpdateAsync(
subscriber.GatewaySubscriptionId,
subscriptionUpdateOptions);
}
}
}
else
{
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) &&
customer.Subscriptions.Any(sub =>
sub.Id == subscriber.GatewaySubscriptionId &&
!sub.AutomaticTax.Enabled) &&
customer.HasTaxLocationVerified())
{
var subscriptionUpdateOptions = new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
DefaultTaxRates = []
};
_ = await _stripeAdapter.SubscriptionUpdateAsync(
subscriber.GatewaySubscriptionId,
subscriptionUpdateOptions);
}
} }
} }
catch catch
@ -1228,6 +1259,8 @@ public class StripePaymentService : IPaymentService
} }
} }
_personalUseTaxStrategy.SetInvoiceCreatePreviewOptions(options);
try try
{ {
var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options);
@ -1270,10 +1303,6 @@ public class StripePaymentService : IPaymentService
var options = new InvoiceCreatePreviewOptions var options = new InvoiceCreatePreviewOptions
{ {
AutomaticTax = new InvoiceAutomaticTaxOptions
{
Enabled = true,
},
Currency = "usd", Currency = "usd",
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{ {
@ -1361,9 +1390,11 @@ public class StripePaymentService : IPaymentService
]; ];
} }
Customer gatewayCustomer = null;
if (!string.IsNullOrWhiteSpace(gatewayCustomerId)) if (!string.IsNullOrWhiteSpace(gatewayCustomerId))
{ {
var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId);
if (gatewayCustomer.Discount != null) if (gatewayCustomer.Discount != null)
{ {
@ -1381,6 +1412,10 @@ public class StripePaymentService : IPaymentService
} }
} }
var automaticTaxFactoryParameters = new AutomaticTaxFactoryParameters(parameters.PasswordManager.Plan);
var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxFactoryParameters);
automaticTaxStrategy.SetInvoiceCreatePreviewOptions(options);
try try
{ {
var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options);

View File

@ -16,7 +16,11 @@ public static class DeviceTypes
DeviceType.LinuxDesktop, DeviceType.LinuxDesktop,
DeviceType.MacOsDesktop, DeviceType.MacOsDesktop,
DeviceType.WindowsDesktop, DeviceType.WindowsDesktop,
DeviceType.UWP, DeviceType.UWP
];
public static IReadOnlyCollection<DeviceType> CliTypes { get; } =
[
DeviceType.WindowsCLI, DeviceType.WindowsCLI,
DeviceType.MacOsCLI, DeviceType.MacOsCLI,
DeviceType.LinuxCLI DeviceType.LinuxCLI
@ -50,6 +54,7 @@ public static class DeviceTypes
{ {
not null when MobileTypes.Contains(deviceType.Value) => ClientType.Mobile, not null when MobileTypes.Contains(deviceType.Value) => ClientType.Mobile,
not null when DesktopTypes.Contains(deviceType.Value) => ClientType.Desktop, not null when DesktopTypes.Contains(deviceType.Value) => ClientType.Desktop,
not null when CliTypes.Contains(deviceType.Value) => ClientType.Cli,
not null when BrowserExtensionTypes.Contains(deviceType.Value) => ClientType.Browser, not null when BrowserExtensionTypes.Contains(deviceType.Value) => ClientType.Browser,
not null when BrowserTypes.Contains(deviceType.Value) => ClientType.Web, not null when BrowserTypes.Contains(deviceType.Value) => ClientType.Web,
_ => ClientType.All _ => ClientType.All

View File

@ -48,9 +48,16 @@ public class CreateManyTaskNotificationsCommand : ICreateManyTaskNotificationsCo
}).ToList(); }).ToList();
var organization = await _organizationRepository.GetByIdAsync(orgId); var organization = await _organizationRepository.GetByIdAsync(orgId);
var orgAdminEmails = await _organizationUserRepository.GetManyDetailsByRoleAsync(orgId, OrganizationUserType.Admin); var orgAdminEmails = (await _organizationUserRepository.GetManyDetailsByRoleAsync(orgId, OrganizationUserType.Admin))
var orgOwnerEmails = await _organizationUserRepository.GetManyDetailsByRoleAsync(orgId, OrganizationUserType.Owner); .Select(u => u.Email)
var orgAdminAndOwnerEmails = orgAdminEmails.Concat(orgOwnerEmails).Select(x => x.Email).Distinct().ToList(); .ToList();
var orgOwnerEmails = (await _organizationUserRepository.GetManyDetailsByRoleAsync(orgId, OrganizationUserType.Owner))
.Select(u => u.Email)
.ToList();
// Ensure proper deserialization of emails
var orgAdminAndOwnerEmails = orgAdminEmails.Concat(orgOwnerEmails).Distinct().ToList();
await _mailService.SendBulkSecurityTaskNotificationsAsync(organization, userTaskCount, orgAdminAndOwnerEmails); await _mailService.SendBulkSecurityTaskNotificationsAsync(organization, userTaskCount, orgAdminAndOwnerEmails);

View File

@ -15,7 +15,7 @@ public interface ICipherService
long requestLength, Guid savingUserId, bool orgAdmin = false); long requestLength, Guid savingUserId, bool orgAdmin = false);
Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, long requestLength, Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, long requestLength,
string attachmentId, Guid organizationShareId); string attachmentId, Guid organizationShareId);
Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); Task DeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false);
Task DeleteManyAsync(IEnumerable<Guid> cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); Task DeleteManyAsync(IEnumerable<Guid> cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false);
Task<DeleteAttachmentResponseData> DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false); Task<DeleteAttachmentResponseData> DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false);
Task PurgeAsync(Guid organizationId); Task PurgeAsync(Guid organizationId);
@ -27,9 +27,9 @@ public interface ICipherService
Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> ciphers, Guid organizationId, Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> ciphers, Guid organizationId,
IEnumerable<Guid> collectionIds, Guid sharingUserId); IEnumerable<Guid> collectionIds, Guid sharingUserId);
Task SaveCollectionsAsync(Cipher cipher, IEnumerable<Guid> collectionIds, Guid savingUserId, bool orgAdmin); Task SaveCollectionsAsync(Cipher cipher, IEnumerable<Guid> collectionIds, Guid savingUserId, bool orgAdmin);
Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); Task SoftDeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false);
Task SoftDeleteManyAsync(IEnumerable<Guid> cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); Task SoftDeleteManyAsync(IEnumerable<Guid> cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false);
Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false); Task RestoreAsync(CipherDetails cipherDetails, Guid restoringUserId, bool orgAdmin = false);
Task<ICollection<CipherOrganizationDetails>> RestoreManyAsync(IEnumerable<Guid> cipherIds, Guid restoringUserId, Guid? organizationId = null, bool orgAdmin = false); Task<ICollection<CipherOrganizationDetails>> RestoreManyAsync(IEnumerable<Guid> cipherIds, Guid restoringUserId, Guid? organizationId = null, bool orgAdmin = false);
Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId); Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId);
Task<AttachmentResponseData> GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId); Task<AttachmentResponseData> GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId);

View File

@ -14,6 +14,7 @@ using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business; using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Services; using Bit.Core.Tools.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Bit.Core.Vault.Authorization.Permissions;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Enums; using Bit.Core.Vault.Enums;
using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Models.Data;
@ -44,6 +45,7 @@ public class CipherService : ICipherService
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly IGetCipherPermissionsForUserQuery _getCipherPermissionsForUserQuery; private readonly IGetCipherPermissionsForUserQuery _getCipherPermissionsForUserQuery;
private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IApplicationCacheService _applicationCacheService;
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
public CipherService( public CipherService(
@ -64,6 +66,7 @@ public class CipherService : ICipherService
ICurrentContext currentContext, ICurrentContext currentContext,
IGetCipherPermissionsForUserQuery getCipherPermissionsForUserQuery, IGetCipherPermissionsForUserQuery getCipherPermissionsForUserQuery,
IPolicyRequirementQuery policyRequirementQuery, IPolicyRequirementQuery policyRequirementQuery,
IApplicationCacheService applicationCacheService,
IFeatureService featureService) IFeatureService featureService)
{ {
_cipherRepository = cipherRepository; _cipherRepository = cipherRepository;
@ -83,6 +86,7 @@ public class CipherService : ICipherService
_currentContext = currentContext; _currentContext = currentContext;
_getCipherPermissionsForUserQuery = getCipherPermissionsForUserQuery; _getCipherPermissionsForUserQuery = getCipherPermissionsForUserQuery;
_policyRequirementQuery = policyRequirementQuery; _policyRequirementQuery = policyRequirementQuery;
_applicationCacheService = applicationCacheService;
_featureService = featureService; _featureService = featureService;
} }
@ -421,19 +425,19 @@ public class CipherService : ICipherService
return response; return response;
} }
public async Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) public async Task DeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false)
{ {
if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) if (!orgAdmin && !await UserCanDeleteAsync(cipherDetails, deletingUserId))
{ {
throw new BadRequestException("You do not have permissions to delete this."); throw new BadRequestException("You do not have permissions to delete this.");
} }
await _cipherRepository.DeleteAsync(cipher); await _cipherRepository.DeleteAsync(cipherDetails);
await _attachmentStorageService.DeleteAttachmentsForCipherAsync(cipher.Id); await _attachmentStorageService.DeleteAttachmentsForCipherAsync(cipherDetails.Id);
await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Deleted); await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_Deleted);
// push // push
await _pushService.PushSyncCipherDeleteAsync(cipher); await _pushService.PushSyncCipherDeleteAsync(cipherDetails);
} }
public async Task DeleteManyAsync(IEnumerable<Guid> cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false) public async Task DeleteManyAsync(IEnumerable<Guid> cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false)
@ -450,8 +454,8 @@ public class CipherService : ICipherService
else else
{ {
var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId);
deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); var filteredCiphers = await FilterCiphersByDeletePermission(ciphers, cipherIdsSet, deletingUserId);
deletingCiphers = filteredCiphers.Select(c => (Cipher)c).ToList();
await _cipherRepository.DeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); await _cipherRepository.DeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId);
} }
@ -703,33 +707,26 @@ public class CipherService : ICipherService
await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds);
} }
public async Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) public async Task SoftDeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false)
{ {
if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) if (!orgAdmin && !await UserCanDeleteAsync(cipherDetails, deletingUserId))
{ {
throw new BadRequestException("You do not have permissions to soft delete this."); throw new BadRequestException("You do not have permissions to soft delete this.");
} }
if (cipher.DeletedDate.HasValue) if (cipherDetails.DeletedDate.HasValue)
{ {
// Already soft-deleted, we can safely ignore this // Already soft-deleted, we can safely ignore this
return; return;
} }
cipher.DeletedDate = cipher.RevisionDate = DateTime.UtcNow; cipherDetails.DeletedDate = cipherDetails.RevisionDate = DateTime.UtcNow;
if (cipher is CipherDetails details) await _cipherRepository.UpsertAsync(cipherDetails);
{ await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted);
await _cipherRepository.UpsertAsync(details);
}
else
{
await _cipherRepository.UpsertAsync(cipher);
}
await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_SoftDeleted);
// push // push
await _pushService.PushSyncCipherUpdateAsync(cipher, null); await _pushService.PushSyncCipherUpdateAsync(cipherDetails, null);
} }
public async Task SoftDeleteManyAsync(IEnumerable<Guid> cipherIds, Guid deletingUserId, Guid? organizationId, bool orgAdmin) public async Task SoftDeleteManyAsync(IEnumerable<Guid> cipherIds, Guid deletingUserId, Guid? organizationId, bool orgAdmin)
@ -746,8 +743,8 @@ public class CipherService : ICipherService
else else
{ {
var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId);
deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); var filteredCiphers = await FilterCiphersByDeletePermission(ciphers, cipherIdsSet, deletingUserId);
deletingCiphers = filteredCiphers.Select(c => (Cipher)c).ToList();
await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId);
} }
@ -762,34 +759,27 @@ public class CipherService : ICipherService
await _pushService.PushSyncCiphersAsync(deletingUserId); await _pushService.PushSyncCiphersAsync(deletingUserId);
} }
public async Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false) public async Task RestoreAsync(CipherDetails cipherDetails, Guid restoringUserId, bool orgAdmin = false)
{ {
if (!orgAdmin && !(await UserCanEditAsync(cipher, restoringUserId))) if (!orgAdmin && !await UserCanRestoreAsync(cipherDetails, restoringUserId))
{ {
throw new BadRequestException("You do not have permissions to delete this."); throw new BadRequestException("You do not have permissions to delete this.");
} }
if (!cipher.DeletedDate.HasValue) if (!cipherDetails.DeletedDate.HasValue)
{ {
// Already restored, we can safely ignore this // Already restored, we can safely ignore this
return; return;
} }
cipher.DeletedDate = null; cipherDetails.DeletedDate = null;
cipher.RevisionDate = DateTime.UtcNow; cipherDetails.RevisionDate = DateTime.UtcNow;
if (cipher is CipherDetails details) await _cipherRepository.UpsertAsync(cipherDetails);
{ await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_Restored);
await _cipherRepository.UpsertAsync(details);
}
else
{
await _cipherRepository.UpsertAsync(cipher);
}
await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Restored);
// push // push
await _pushService.PushSyncCipherUpdateAsync(cipher, null); await _pushService.PushSyncCipherUpdateAsync(cipherDetails, null);
} }
public async Task<ICollection<CipherOrganizationDetails>> RestoreManyAsync(IEnumerable<Guid> cipherIds, Guid restoringUserId, Guid? organizationId = null, bool orgAdmin = false) public async Task<ICollection<CipherOrganizationDetails>> RestoreManyAsync(IEnumerable<Guid> cipherIds, Guid restoringUserId, Guid? organizationId = null, bool orgAdmin = false)
@ -812,8 +802,8 @@ public class CipherService : ICipherService
else else
{ {
var ciphers = await _cipherRepository.GetManyByUserIdAsync(restoringUserId); var ciphers = await _cipherRepository.GetManyByUserIdAsync(restoringUserId);
restoringCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(c => (CipherOrganizationDetails)c).ToList(); var filteredCiphers = await FilterCiphersByDeletePermission(ciphers, cipherIdsSet, restoringUserId);
restoringCiphers = filteredCiphers.Select(c => (CipherOrganizationDetails)c).ToList();
revisionDate = await _cipherRepository.RestoreAsync(restoringCiphers.Select(c => c.Id), restoringUserId); revisionDate = await _cipherRepository.RestoreAsync(restoringCiphers.Select(c => c.Id), restoringUserId);
} }
@ -844,6 +834,34 @@ public class CipherService : ICipherService
return await _cipherRepository.GetCanEditByIdAsync(userId, cipher.Id); return await _cipherRepository.GetCanEditByIdAsync(userId, cipher.Id);
} }
private async Task<bool> UserCanDeleteAsync(CipherDetails cipher, Guid userId)
{
if (!_featureService.IsEnabled(FeatureFlagKeys.LimitItemDeletion))
{
return await UserCanEditAsync(cipher, userId);
}
var user = await _userService.GetUserByIdAsync(userId);
var organizationAbility = cipher.OrganizationId.HasValue ?
await _applicationCacheService.GetOrganizationAbilityAsync(cipher.OrganizationId.Value) : null;
return NormalCipherPermissions.CanDelete(user, cipher, organizationAbility);
}
private async Task<bool> UserCanRestoreAsync(CipherDetails cipher, Guid userId)
{
if (!_featureService.IsEnabled(FeatureFlagKeys.LimitItemDeletion))
{
return await UserCanEditAsync(cipher, userId);
}
var user = await _userService.GetUserByIdAsync(userId);
var organizationAbility = cipher.OrganizationId.HasValue ?
await _applicationCacheService.GetOrganizationAbilityAsync(cipher.OrganizationId.Value) : null;
return NormalCipherPermissions.CanRestore(user, cipher, organizationAbility);
}
private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate) private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate)
{ {
if (cipher.Id == default || !lastKnownRevisionDate.HasValue) if (cipher.Id == default || !lastKnownRevisionDate.HasValue)
@ -1010,4 +1028,35 @@ public class CipherService : ICipherService
cipher.Data = JsonSerializer.Serialize(newCipherData); cipher.Data = JsonSerializer.Serialize(newCipherData);
} }
} }
// This method is used to filter ciphers based on the user's permissions to delete them.
// It supports both the old and new logic depending on the feature flag.
private async Task<List<T>> FilterCiphersByDeletePermission<T>(
IEnumerable<T> ciphers,
HashSet<Guid> cipherIdsSet,
Guid userId) where T : CipherDetails
{
if (!_featureService.IsEnabled(FeatureFlagKeys.LimitItemDeletion))
{
return ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).ToList();
}
var user = await _userService.GetUserByIdAsync(userId);
var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var filteredCiphers = ciphers
.Where(c => cipherIdsSet.Contains(c.Id))
.GroupBy(c => c.OrganizationId)
.SelectMany(group =>
{
var organizationAbility = group.Key.HasValue &&
organizationAbilities.TryGetValue(group.Key.Value, out var ability) ?
ability : null;
return group.Where(c => NormalCipherPermissions.CanDelete(user, c, organizationAbility));
})
.ToList();
return filteredCiphers;
}
} }

View File

@ -1,8 +1,10 @@
using System.Data; using System.Data;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities;
using Dapper; using Dapper;
using Microsoft.Data.SqlClient; using Microsoft.Data.SqlClient;
@ -109,4 +111,35 @@ public class DeviceRepository : Repository<Device, Guid>, IDeviceRepository
commandType: CommandType.StoredProcedure); commandType: CommandType.StoredProcedure);
} }
} }
public UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable<Device> devices)
{
return async (SqlConnection connection, SqlTransaction transaction) =>
{
const string sql = @"
UPDATE D
SET
D.[EncryptedPublicKey] = UD.[encryptedPublicKey],
D.[EncryptedUserKey] = UD.[encryptedUserKey]
FROM
[dbo].[Device] D
INNER JOIN
OPENJSON(@DeviceCredentials)
WITH (
id UNIQUEIDENTIFIER,
encryptedPublicKey NVARCHAR(MAX),
encryptedUserKey NVARCHAR(MAX)
) UD
ON UD.[id] = D.[Id]
WHERE
D.[UserId] = @UserId";
var deviceCredentials = CoreHelpers.ClassToJsonData(devices);
await connection.ExecuteAsync(
sql,
new { UserId = userId, DeviceCredentials = deviceCredentials },
transaction: transaction,
commandType: CommandType.Text);
};
}
} }

View File

@ -68,12 +68,11 @@ public class WebAuthnCredentialRepository : Repository<Core.Auth.Entities.WebAut
var newCreds = credentials.ToList(); var newCreds = credentials.ToList();
using var scope = ServiceScopeFactory.CreateScope(); using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope); var dbContext = GetDatabaseContext(scope);
var userWebauthnCredentials = await GetDbSet(dbContext)
.Where(wc => wc.Id == wc.Id) var newCredIds = newCreds.Select(nwc => nwc.Id).ToList();
var validUserWebauthnCredentials = await GetDbSet(dbContext)
.Where(wc => wc.UserId == userId && newCredIds.Contains(wc.Id))
.ToListAsync(); .ToListAsync();
var validUserWebauthnCredentials = userWebauthnCredentials
.Where(wc => newCreds.Any(nwc => nwc.Id == wc.Id))
.Where(wc => wc.UserId == userId);
foreach (var wc in validUserWebauthnCredentials) foreach (var wc in validUserWebauthnCredentials)
{ {

View File

@ -52,6 +52,7 @@ public class NotificationStatusDetailsViewQuery(Guid userId, ClientType clientTy
ClientType = x.n.ClientType, ClientType = x.n.ClientType,
UserId = x.n.UserId, UserId = x.n.UserId,
OrganizationId = x.n.OrganizationId, OrganizationId = x.n.OrganizationId,
TaskId = x.n.TaskId,
Title = x.n.Title, Title = x.n.Title,
Body = x.n.Body, Body = x.n.Body,
CreationDate = x.n.CreationDate, CreationDate = x.n.CreationDate,

View File

@ -1,5 +1,6 @@
using AutoMapper; using AutoMapper;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Infrastructure.EntityFramework.Auth.Repositories.Queries; using Bit.Infrastructure.EntityFramework.Auth.Repositories.Queries;
@ -91,4 +92,30 @@ public class DeviceRepository : Repository<Core.Entities.Device, Device, Guid>,
return await query.GetQuery(dbContext, userId, expirationMinutes).ToListAsync(); return await query.GetQuery(dbContext, userId, expirationMinutes).ToListAsync();
} }
} }
public UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable<Core.Entities.Device> devices)
{
return async (_, _) =>
{
var deviceUpdates = devices.ToList();
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
var userDevices = await GetDbSet(dbContext)
.Where(device => device.UserId == userId)
.ToListAsync();
var userDevicesWithUpdatesPending = userDevices
.Where(existingDevice => deviceUpdates.Any(updatedDevice => updatedDevice.Id == existingDevice.Id))
.ToList();
foreach (var deviceToUpdate in userDevicesWithUpdatesPending)
{
var deviceUpdate = deviceUpdates.First(deviceUpdate => deviceUpdate.Id == deviceToUpdate.Id);
deviceToUpdate.EncryptedPublicKey = deviceUpdate.EncryptedPublicKey;
deviceToUpdate.EncryptedUserKey = deviceUpdate.EncryptedUserKey;
}
await dbContext.SaveChangesAsync();
};
}
} }

View File

@ -135,6 +135,11 @@ public static class HubHelpers
} }
break; break;
case PushType.PendingSecurityTasks:
var pendingTasksData = JsonSerializer.Deserialize<PushNotificationData<UserPushNotification>>(notificationJson, _deserializerOptions);
await hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString())
.SendAsync(_receiveMessageMethod, pendingTasksData, cancellationToken);
break;
default: default:
break; break;
} }

View File

@ -1,10 +1,20 @@
CREATE VIEW [dbo].[NotificationStatusDetailsView] CREATE VIEW [dbo].[NotificationStatusDetailsView]
AS AS
SELECT SELECT
N.*, N.[Id],
NS.UserId AS NotificationStatusUserId, N.[Priority],
NS.ReadDate, N.[Global],
NS.DeletedDate N.[ClientType],
N.[UserId],
N.[OrganizationId],
N.[Title],
N.[Body],
N.[CreationDate],
N.[RevisionDate],
N.[TaskId],
NS.[UserId] AS [NotificationStatusUserId],
NS.[ReadDate],
NS.[DeletedDate]
FROM FROM
[dbo].[Notification] AS N [dbo].[Notification] AS N
LEFT JOIN LEFT JOIN

View File

@ -29,6 +29,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture<ApiApplication
private readonly ApiApplicationFactory _factory; private readonly ApiApplicationFactory _factory;
private readonly LoginHelper _loginHelper; private readonly LoginHelper _loginHelper;
private readonly IUserRepository _userRepository; private readonly IUserRepository _userRepository;
private readonly IDeviceRepository _deviceRepository;
private readonly IPasswordHasher<User> _passwordHasher; private readonly IPasswordHasher<User> _passwordHasher;
private string _ownerEmail = null!; private string _ownerEmail = null!;
@ -40,6 +41,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture<ApiApplication
_client = factory.CreateClient(); _client = factory.CreateClient();
_loginHelper = new LoginHelper(_factory, _client); _loginHelper = new LoginHelper(_factory, _client);
_userRepository = _factory.GetService<IUserRepository>(); _userRepository = _factory.GetService<IUserRepository>();
_deviceRepository = _factory.GetService<IDeviceRepository>();
_emergencyAccessRepository = _factory.GetService<IEmergencyAccessRepository>(); _emergencyAccessRepository = _factory.GetService<IEmergencyAccessRepository>();
_organizationUserRepository = _factory.GetService<IOrganizationUserRepository>(); _organizationUserRepository = _factory.GetService<IOrganizationUserRepository>();
_passwordHasher = _factory.GetService<IPasswordHasher<User>>(); _passwordHasher = _factory.GetService<IPasswordHasher<User>>();
@ -238,10 +240,12 @@ public class AccountsKeyManagementControllerTests : IClassFixture<ApiApplication
]; ];
request.AccountUnlockData.MasterPasswordUnlockData.MasterKeyEncryptedUserKey = _mockEncryptedString; request.AccountUnlockData.MasterPasswordUnlockData.MasterKeyEncryptedUserKey = _mockEncryptedString;
request.AccountUnlockData.PasskeyUnlockData = []; request.AccountUnlockData.PasskeyUnlockData = [];
request.AccountUnlockData.DeviceKeyUnlockData = [];
request.AccountUnlockData.EmergencyAccessUnlockData = []; request.AccountUnlockData.EmergencyAccessUnlockData = [];
request.AccountUnlockData.OrganizationAccountRecoveryUnlockData = []; request.AccountUnlockData.OrganizationAccountRecoveryUnlockData = [];
var response = await _client.PostAsJsonAsync("/accounts/key-management/rotate-user-account-keys", request); var response = await _client.PostAsJsonAsync("/accounts/key-management/rotate-user-account-keys", request);
var responseMessage = await response.Content.ReadAsStringAsync();
response.EnsureSuccessStatusCode(); response.EnsureSuccessStatusCode();
var userNewState = await _userRepository.GetByEmailAsync(_ownerEmail); var userNewState = await _userRepository.GetByEmailAsync(_ownerEmail);

View File

@ -10,14 +10,19 @@ namespace Bit.Api.Test.AdminConsole.Models.Response.Helpers;
public class PolicyDetailResponsesTests public class PolicyDetailResponsesTests
{ {
[Fact] [Theory]
public async Task GetSingleOrgPolicyDetailResponseAsync_GivenPolicyEntity_WhenIsSingleOrgTypeAndHasVerifiedDomains_ThenShouldNotBeAbleToToggle() [InlineData(true, false)]
[InlineData(false, true)]
public async Task GetSingleOrgPolicyDetailResponseAsync_WhenIsSingleOrgTypeAndHasVerifiedDomains_ShouldReturnExpectedToggleState(
bool policyEnabled,
bool expectedCanToggle)
{ {
var fixture = new Fixture(); var fixture = new Fixture();
var policy = fixture.Build<Policy>() var policy = fixture.Build<Policy>()
.Without(p => p.Data) .Without(p => p.Data)
.With(p => p.Type, PolicyType.SingleOrg) .With(p => p.Type, PolicyType.SingleOrg)
.With(p => p.Enabled, policyEnabled)
.Create(); .Create();
var querySub = Substitute.For<IOrganizationHasVerifiedDomainsQuery>(); var querySub = Substitute.For<IOrganizationHasVerifiedDomainsQuery>();
@ -26,11 +31,11 @@ public class PolicyDetailResponsesTests
var result = await policy.GetSingleOrgPolicyDetailResponseAsync(querySub); var result = await policy.GetSingleOrgPolicyDetailResponseAsync(querySub);
Assert.False(result.CanToggleState); Assert.Equal(expectedCanToggle, result.CanToggleState);
} }
[Fact] [Fact]
public async Task GetSingleOrgPolicyDetailResponseAsync_GivenPolicyEntity_WhenIsNotSingleOrgType_ThenShouldThrowArgumentException() public async Task GetSingleOrgPolicyDetailResponseAsync_WhenIsNotSingleOrgType_ThenShouldThrowArgumentException()
{ {
var fixture = new Fixture(); var fixture = new Fixture();
@ -49,7 +54,7 @@ public class PolicyDetailResponsesTests
} }
[Fact] [Fact]
public async Task GetSingleOrgPolicyDetailResponseAsync_GivenPolicyEntity_WhenIsSingleOrgTypeAndDoesNotHaveVerifiedDomains_ThenShouldBeAbleToToggle() public async Task GetSingleOrgPolicyDetailResponseAsync_WhenIsSingleOrgTypeAndDoesNotHaveVerifiedDomains_ThenShouldBeAbleToToggle()
{ {
var fixture = new Fixture(); var fixture = new Fixture();

View File

@ -0,0 +1,49 @@
using Bit.Api.KeyManagement.Validators;
using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Api.Test.KeyManagement.Validators;
[SutProviderCustomize]
public class DeviceRotationValidatorTests
{
[Theory, BitAutoData]
public async Task ValidateAsync_SentDevicesAreEmptyButDatabaseDevicesAreNot_Throws(
SutProvider<DeviceRotationValidator> sutProvider, User user, IEnumerable<OtherDeviceKeysUpdateRequestModel> devices)
{
var userCiphers = devices.Select(c => new Device { Id = c.DeviceId, EncryptedPrivateKey = "EncryptedPrivateKey", EncryptedPublicKey = "EncryptedPublicKey", EncryptedUserKey = "EncryptedUserKey" }).ToList();
sutProvider.GetDependency<IDeviceRepository>().GetManyByUserIdAsync(user.Id)
.Returns(userCiphers);
await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.ValidateAsync(user, Enumerable.Empty<OtherDeviceKeysUpdateRequestModel>()));
}
[Theory, BitAutoData]
public async Task ValidateAsync_SentDevicesTrustedButDatabaseUntrusted_Throws(
SutProvider<DeviceRotationValidator> sutProvider, User user, IEnumerable<OtherDeviceKeysUpdateRequestModel> devices)
{
var userCiphers = devices.Select(c => new Device { Id = c.DeviceId, EncryptedPrivateKey = "Key", EncryptedPublicKey = "Key", EncryptedUserKey = "Key" }).ToList();
sutProvider.GetDependency<IDeviceRepository>().GetManyByUserIdAsync(user.Id)
.Returns(userCiphers);
await Assert.ThrowsAsync<BadRequestException>(async () => await sutProvider.Sut.ValidateAsync(user, [
new OtherDeviceKeysUpdateRequestModel { DeviceId = userCiphers.First().Id, EncryptedPublicKey = null, EncryptedUserKey = null }
]));
}
[Theory, BitAutoData]
public async Task ValidateAsync_Validates(
SutProvider<DeviceRotationValidator> sutProvider, User user, IEnumerable<OtherDeviceKeysUpdateRequestModel> devices)
{
var userCiphers = devices.Select(c => new Device { Id = c.DeviceId, EncryptedPrivateKey = "Key", EncryptedPublicKey = "Key", EncryptedUserKey = "Key" }).ToList().Slice(0, 1);
sutProvider.GetDependency<IDeviceRepository>().GetManyByUserIdAsync(user.Id)
.Returns(userCiphers);
Assert.NotEmpty(await sutProvider.Sut.ValidateAsync(user, [
new OtherDeviceKeysUpdateRequestModel { DeviceId = userCiphers.First().Id, EncryptedPublicKey = "Key", EncryptedUserKey = "Key" }
]));
}
}

View File

@ -67,6 +67,7 @@ public class NotificationsControllerTests
Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date); Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date);
Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate); Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate);
Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate); Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate);
Assert.Equal(expectedNotificationStatusDetails.TaskId, notificationResponseModel.TaskId);
}); });
Assert.Null(listResponse.ContinuationToken); Assert.Null(listResponse.ContinuationToken);
@ -116,6 +117,7 @@ public class NotificationsControllerTests
Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date); Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date);
Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate); Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate);
Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate); Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate);
Assert.Equal(expectedNotificationStatusDetails.TaskId, notificationResponseModel.TaskId);
}); });
Assert.Equal("2", listResponse.ContinuationToken); Assert.Equal("2", listResponse.ContinuationToken);
@ -164,6 +166,7 @@ public class NotificationsControllerTests
Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date); Assert.Equal(expectedNotificationStatusDetails.RevisionDate, notificationResponseModel.Date);
Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate); Assert.Equal(expectedNotificationStatusDetails.ReadDate, notificationResponseModel.ReadDate);
Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate); Assert.Equal(expectedNotificationStatusDetails.DeletedDate, notificationResponseModel.DeletedDate);
Assert.Equal(expectedNotificationStatusDetails.TaskId, notificationResponseModel.TaskId);
}); });
Assert.Null(listResponse.ContinuationToken); Assert.Null(listResponse.ContinuationToken);

View File

@ -26,6 +26,7 @@ public class NotificationResponseModelTests
ClientType = ClientType.All, ClientType = ClientType.All,
Title = "Test Title", Title = "Test Title",
Body = "Test Body", Body = "Test Body",
TaskId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow - TimeSpan.FromMinutes(3), RevisionDate = DateTime.UtcNow - TimeSpan.FromMinutes(3),
ReadDate = DateTime.UtcNow - TimeSpan.FromMinutes(1), ReadDate = DateTime.UtcNow - TimeSpan.FromMinutes(1),
DeletedDate = DateTime.UtcNow, DeletedDate = DateTime.UtcNow,
@ -39,5 +40,6 @@ public class NotificationResponseModelTests
Assert.Equal(model.Date, notificationStatusDetails.RevisionDate); Assert.Equal(model.Date, notificationStatusDetails.RevisionDate);
Assert.Equal(model.ReadDate, notificationStatusDetails.ReadDate); Assert.Equal(model.ReadDate, notificationStatusDetails.ReadDate);
Assert.Equal(model.DeletedDate, notificationStatusDetails.DeletedDate); Assert.Equal(model.DeletedDate, notificationStatusDetails.DeletedDate);
Assert.Equal(model.TaskId, notificationStatusDetails.TaskId);
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@ using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using Bit.Test.Common.Helpers; using Bit.Test.Common.Helpers;
using Microsoft.Extensions.Logging;
using NSubstitute; using NSubstitute;
using Xunit; using Xunit;
using GlobalSettings = Bit.Core.Settings.GlobalSettings; using GlobalSettings = Bit.Core.Settings.GlobalSettings;
@ -273,78 +274,7 @@ public class AuthRequestServiceTests
/// each of them. /// each of them.
/// </summary> /// </summary>
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task CreateAuthRequestAsync_AdminApproval_CreatesForEachOrganization( public async Task CreateAuthRequestAsync_AdminApproval_CreatesForEachOrganization_SendsEmails(
SutProvider<AuthRequestService> sutProvider,
AuthRequestCreateRequestModel createModel,
User user,
OrganizationUser organizationUser1,
OrganizationUser organizationUser2)
{
createModel.Type = AuthRequestType.AdminApproval;
user.Email = createModel.Email;
organizationUser1.UserId = user.Id;
organizationUser2.UserId = user.Id;
sutProvider.GetDependency<IUserRepository>()
.GetByEmailAsync(user.Email)
.Returns(user);
sutProvider.GetDependency<ICurrentContext>()
.DeviceType
.Returns(DeviceType.ChromeExtension);
sutProvider.GetDependency<ICurrentContext>()
.UserId
.Returns(user.Id);
sutProvider.GetDependency<IGlobalSettings>()
.PasswordlessAuth.KnownDevicesOnly
.Returns(false);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(user.Id)
.Returns(new List<OrganizationUser>
{
organizationUser1,
organizationUser2,
});
sutProvider.GetDependency<IAuthRequestRepository>()
.CreateAsync(Arg.Any<AuthRequest>())
.Returns(c => c.ArgAt<AuthRequest>(0));
var authRequest = await sutProvider.Sut.CreateAuthRequestAsync(createModel);
Assert.Equal(organizationUser1.OrganizationId, authRequest.OrganizationId);
await sutProvider.GetDependency<IAuthRequestRepository>()
.Received(1)
.CreateAsync(Arg.Is<AuthRequest>(o => o.OrganizationId == organizationUser1.OrganizationId));
await sutProvider.GetDependency<IAuthRequestRepository>()
.Received(1)
.CreateAsync(Arg.Is<AuthRequest>(o => o.OrganizationId == organizationUser2.OrganizationId));
await sutProvider.GetDependency<IAuthRequestRepository>()
.Received(2)
.CreateAsync(Arg.Any<AuthRequest>());
await sutProvider.GetDependency<IEventService>()
.Received(1)
.LogUserEventAsync(user.Id, EventType.User_RequestedDeviceApproval);
await sutProvider.GetDependency<IMailService>()
.DidNotReceiveWithAnyArgs()
.SendDeviceApprovalRequestedNotificationEmailAsync(
Arg.Any<IEnumerable<string>>(),
Arg.Any<Guid>(),
Arg.Any<string>(),
Arg.Any<string>());
}
[Theory, BitAutoData]
public async Task CreateAuthRequestAsync_AdminApproval_WithAdminNotifications_CreatesForEachOrganization_SendsEmails(
SutProvider<AuthRequestService> sutProvider, SutProvider<AuthRequestService> sutProvider,
AuthRequestCreateRequestModel createModel, AuthRequestCreateRequestModel createModel,
User user, User user,
@ -369,10 +299,6 @@ public class AuthRequestServiceTests
ManageResetPassword = true, ManageResetPassword = true,
}); });
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DeviceApprovalRequestAdminNotifications)
.Returns(true);
sutProvider.GetDependency<IUserRepository>() sutProvider.GetDependency<IUserRepository>()
.GetByEmailAsync(user.Email) .GetByEmailAsync(user.Email)
.Returns(user); .Returns(user);
@ -470,6 +396,87 @@ public class AuthRequestServiceTests
user.Name); user.Name);
} }
[Theory, BitAutoData]
public async Task CreateAuthRequestAsync_AdminApproval_WithAdminNotifications_AndNoAdminEmails_ShouldNotSendNotificationEmails(
SutProvider<AuthRequestService> sutProvider,
AuthRequestCreateRequestModel createModel,
User user,
OrganizationUser organizationUser1)
{
createModel.Type = AuthRequestType.AdminApproval;
user.Email = createModel.Email;
organizationUser1.UserId = user.Id;
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.DeviceApprovalRequestAdminNotifications)
.Returns(true);
sutProvider.GetDependency<IUserRepository>()
.GetByEmailAsync(user.Email)
.Returns(user);
sutProvider.GetDependency<ICurrentContext>()
.DeviceType
.Returns(DeviceType.ChromeExtension);
sutProvider.GetDependency<ICurrentContext>()
.UserId
.Returns(user.Id);
sutProvider.GetDependency<IGlobalSettings>()
.PasswordlessAuth.KnownDevicesOnly
.Returns(false);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByUserAsync(user.Id)
.Returns(new List<OrganizationUser>
{
organizationUser1,
});
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyByMinimumRoleAsync(organizationUser1.OrganizationId, OrganizationUserType.Admin)
.Returns([]);
sutProvider.GetDependency<IOrganizationUserRepository>()
.GetManyDetailsByRoleAsync(organizationUser1.OrganizationId, OrganizationUserType.Custom)
.Returns([]);
sutProvider.GetDependency<IAuthRequestRepository>()
.CreateAsync(Arg.Any<AuthRequest>())
.Returns(c => c.ArgAt<AuthRequest>(0));
var authRequest = await sutProvider.Sut.CreateAuthRequestAsync(createModel);
Assert.Equal(organizationUser1.OrganizationId, authRequest.OrganizationId);
await sutProvider.GetDependency<IAuthRequestRepository>()
.Received(1)
.CreateAsync(Arg.Is<AuthRequest>(o => o.OrganizationId == organizationUser1.OrganizationId));
await sutProvider.GetDependency<IAuthRequestRepository>()
.Received(1)
.CreateAsync(Arg.Any<AuthRequest>());
await sutProvider.GetDependency<IEventService>()
.Received(1)
.LogUserEventAsync(user.Id, EventType.User_RequestedDeviceApproval);
await sutProvider.GetDependency<IMailService>()
.Received(0)
.SendDeviceApprovalRequestedNotificationEmailAsync(
Arg.Any<IEnumerable<string>>(),
Arg.Any<Guid>(),
Arg.Any<string>(),
Arg.Any<string>());
var expectedLogMessage = "There are no admin emails to send to.";
sutProvider.GetDependency<ILogger<AuthRequestService>>()
.Received(1)
.LogWarning(expectedLogMessage);
}
/// <summary> /// <summary>
/// Story: When an <see cref="AuthRequest"> is approved we want to update it in the database so it cannot have /// Story: When an <see cref="AuthRequest"> is approved we want to update it in the database so it cannot have
/// it's status changed again and we want to push a notification to let the user know of the approval. /// it's status changed again and we want to push a notification to let the user know of the approval.

View File

@ -0,0 +1,492 @@
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Stripe;
using Xunit;
namespace Bit.Core.Test.Billing.Services.Implementations.AutomaticTax;
[SutProviderCustomize]
public class BusinessUseAutomaticTaxStrategyTests
{
[Theory]
[BitAutoData]
public void GetUpdateOptions_ReturnsNull_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription();
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(false);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.Null(actual);
}
[Theory]
[BitAutoData]
public void GetUpdateOptions_ReturnsNull_WhenSubscriptionDoesNotNeedUpdating(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Address = new Address
{
Country = "US",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.Null(actual);
}
[Theory]
[BitAutoData]
public void GetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.NotNull(actual);
Assert.False(actual.AutomaticTax.Enabled);
}
[Theory]
[BitAutoData]
public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForAmericanCustomers(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = false
},
Customer = new Customer
{
Address = new Address
{
Country = "US",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.NotNull(actual);
Assert.True(actual.AutomaticTax.Enabled);
}
[Theory]
[BitAutoData]
public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = false
},
Customer = new Customer
{
Address = new Address
{
Country = "ES",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
},
TaxIds = new StripeList<TaxId>
{
Data = new List<TaxId>
{
new()
{
Country = "ES",
Type = "eu_vat",
Value = "ESZ8880999Z"
}
}
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.NotNull(actual);
Assert.True(actual.AutomaticTax.Enabled);
}
[Theory]
[BitAutoData]
public void GetUpdateOptions_ThrowsArgumentNullException_WhenTaxIdsIsNull(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Address = new Address
{
Country = "ES",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
},
TaxIds = null
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
Assert.Throws<ArgumentNullException>(() => sutProvider.Sut.GetUpdateOptions(subscription));
}
[Theory]
[BitAutoData]
public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Address = new Address
{
Country = "ES",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
},
TaxIds = new StripeList<TaxId>
{
Data = new List<TaxId>()
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.NotNull(actual);
Assert.False(actual.AutomaticTax.Enabled);
}
[Theory]
[BitAutoData]
public void SetUpdateOptions_SetsNothing_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var options = new SubscriptionUpdateOptions();
var subscription = new Subscription
{
Customer = new Customer
{
Address = new()
{
Country = "US"
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(false);
sutProvider.Sut.SetUpdateOptions(options, subscription);
Assert.Null(options.AutomaticTax);
}
[Theory]
[BitAutoData]
public void SetUpdateOptions_SetsNothing_WhenSubscriptionDoesNotNeedUpdating(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var options = new SubscriptionUpdateOptions();
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Address = new Address
{
Country = "US",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
sutProvider.Sut.SetUpdateOptions(options, subscription);
Assert.Null(options.AutomaticTax);
}
[Theory]
[BitAutoData]
public void SetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var options = new SubscriptionUpdateOptions();
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
sutProvider.Sut.SetUpdateOptions(options, subscription);
Assert.False(options.AutomaticTax.Enabled);
}
[Theory]
[BitAutoData]
public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForAmericanCustomers(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var options = new SubscriptionUpdateOptions();
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = false
},
Customer = new Customer
{
Address = new Address
{
Country = "US",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
sutProvider.Sut.SetUpdateOptions(options, subscription);
Assert.True(options.AutomaticTax!.Enabled);
}
[Theory]
[BitAutoData]
public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var options = new SubscriptionUpdateOptions();
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = false
},
Customer = new Customer
{
Address = new Address
{
Country = "ES",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
},
TaxIds = new StripeList<TaxId>
{
Data = new List<TaxId>
{
new()
{
Country = "ES",
Type = "eu_vat",
Value = "ESZ8880999Z"
}
}
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
sutProvider.Sut.SetUpdateOptions(options, subscription);
Assert.True(options.AutomaticTax!.Enabled);
}
[Theory]
[BitAutoData]
public void SetUpdateOptions_ThrowsArgumentNullException_WhenTaxIdsIsNull(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var options = new SubscriptionUpdateOptions();
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Address = new Address
{
Country = "ES",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
},
TaxIds = null
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
Assert.Throws<ArgumentNullException>(() => sutProvider.Sut.SetUpdateOptions(options, subscription));
}
[Theory]
[BitAutoData]
public void SetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds(
SutProvider<BusinessUseAutomaticTaxStrategy> sutProvider)
{
var options = new SubscriptionUpdateOptions();
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Address = new Address
{
Country = "ES",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
},
TaxIds = new StripeList<TaxId>
{
Data = new List<TaxId>()
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
sutProvider.Sut.SetUpdateOptions(options, subscription);
Assert.False(options.AutomaticTax!.Enabled);
}
}

View File

@ -0,0 +1,217 @@
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Stripe;
using Xunit;
namespace Bit.Core.Test.Billing.Services.Implementations.AutomaticTax;
[SutProviderCustomize]
public class PersonalUseAutomaticTaxStrategyTests
{
[Theory]
[BitAutoData]
public void GetUpdateOptions_ReturnsNull_WhenFeatureFlagAllowingToUpdateSubscriptionsIsDisabled(
SutProvider<PersonalUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription();
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(false);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.Null(actual);
}
[Theory]
[BitAutoData]
public void GetUpdateOptions_ReturnsNull_WhenSubscriptionDoesNotNeedUpdating(
SutProvider<PersonalUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Address = new Address
{
Country = "US",
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.Null(actual);
}
[Theory]
[BitAutoData]
public void GetUpdateOptions_SetsAutomaticTaxToFalse_WhenTaxLocationIsUnrecognizedOrInvalid(
SutProvider<PersonalUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = true
},
Customer = new Customer
{
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.UnrecognizedLocation
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.NotNull(actual);
Assert.False(actual.AutomaticTax.Enabled);
}
[Theory]
[BitAutoData("CA")]
[BitAutoData("ES")]
[BitAutoData("US")]
public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForAllCountries(
string country, SutProvider<PersonalUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = false
},
Customer = new Customer
{
Address = new Address
{
Country = country
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.NotNull(actual);
Assert.True(actual.AutomaticTax.Enabled);
}
[Theory]
[BitAutoData("CA")]
[BitAutoData("ES")]
[BitAutoData("US")]
public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithTaxIds(
string country, SutProvider<PersonalUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = false
},
Customer = new Customer
{
Address = new Address
{
Country = country,
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
},
TaxIds = new StripeList<TaxId>
{
Data = new List<TaxId>
{
new()
{
Country = "ES",
Type = "eu_vat",
Value = "ESZ8880999Z"
}
}
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.NotNull(actual);
Assert.True(actual.AutomaticTax.Enabled);
}
[Theory]
[BitAutoData("CA")]
[BitAutoData("ES")]
[BitAutoData("US")]
public void GetUpdateOptions_SetsAutomaticTaxToTrue_ForGlobalCustomersWithoutTaxIds(
string country, SutProvider<PersonalUseAutomaticTaxStrategy> sutProvider)
{
var subscription = new Subscription
{
AutomaticTax = new SubscriptionAutomaticTax
{
Enabled = false
},
Customer = new Customer
{
Address = new Address
{
Country = country
},
Tax = new CustomerTax
{
AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported
},
TaxIds = new StripeList<TaxId>
{
Data = new List<TaxId>()
}
}
};
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(Arg.Is<string>(p => p == FeatureFlagKeys.PM19422_AllowAutomaticTaxUpdates))
.Returns(true);
var actual = sutProvider.Sut.GetUpdateOptions(subscription);
Assert.NotNull(actual);
Assert.True(actual.AutomaticTax.Enabled);
}
}

View File

@ -0,0 +1,105 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models.StaticStore.Plans;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Entities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Billing.Services.Implementations;
[SutProviderCustomize]
public class AutomaticTaxFactoryTests
{
[BitAutoData]
[Theory]
public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenSubscriberIsUser(SutProvider<AutomaticTaxFactory> sut)
{
var parameters = new AutomaticTaxFactoryParameters(new User(), []);
var actual = await sut.Sut.CreateAsync(parameters);
Assert.IsType<PersonalUseAutomaticTaxStrategy>(actual);
}
[BitAutoData]
[Theory]
public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenSubscriberIsOrganizationWithFamiliesAnnuallyPrice(
SutProvider<AutomaticTaxFactory> sut)
{
var familiesPlan = new FamiliesPlan();
var parameters = new AutomaticTaxFactoryParameters(new Organization(), [familiesPlan.PasswordManager.StripePlanId]);
sut.GetDependency<IPricingClient>()
.GetPlanOrThrow(Arg.Is<PlanType>(p => p == PlanType.FamiliesAnnually))
.Returns(new FamiliesPlan());
sut.GetDependency<IPricingClient>()
.GetPlanOrThrow(Arg.Is<PlanType>(p => p == PlanType.FamiliesAnnually2019))
.Returns(new Families2019Plan());
var actual = await sut.Sut.CreateAsync(parameters);
Assert.IsType<PersonalUseAutomaticTaxStrategy>(actual);
}
[Theory]
[BitAutoData]
public async Task CreateAsync_ReturnsBusinessUseStrategy_WhenSubscriberIsOrganizationWithBusinessUsePrice(
EnterpriseAnnually plan,
SutProvider<AutomaticTaxFactory> sut)
{
var parameters = new AutomaticTaxFactoryParameters(new Organization(), [plan.PasswordManager.StripePlanId]);
sut.GetDependency<IPricingClient>()
.GetPlanOrThrow(Arg.Is<PlanType>(p => p == PlanType.FamiliesAnnually))
.Returns(new FamiliesPlan());
sut.GetDependency<IPricingClient>()
.GetPlanOrThrow(Arg.Is<PlanType>(p => p == PlanType.FamiliesAnnually2019))
.Returns(new Families2019Plan());
var actual = await sut.Sut.CreateAsync(parameters);
Assert.IsType<BusinessUseAutomaticTaxStrategy>(actual);
}
[Theory]
[BitAutoData]
public async Task CreateAsync_ReturnsPersonalUseStrategy_WhenPlanIsMeantForPersonalUse(SutProvider<AutomaticTaxFactory> sut)
{
var parameters = new AutomaticTaxFactoryParameters(PlanType.FamiliesAnnually);
sut.GetDependency<IPricingClient>()
.GetPlanOrThrow(Arg.Is<PlanType>(p => p == parameters.PlanType.Value))
.Returns(new FamiliesPlan());
var actual = await sut.Sut.CreateAsync(parameters);
Assert.IsType<PersonalUseAutomaticTaxStrategy>(actual);
}
[Theory]
[BitAutoData]
public async Task CreateAsync_ReturnsBusinessUseStrategy_WhenPlanIsMeantForBusinessUse(SutProvider<AutomaticTaxFactory> sut)
{
var parameters = new AutomaticTaxFactoryParameters(PlanType.EnterpriseAnnually);
sut.GetDependency<IPricingClient>()
.GetPlanOrThrow(Arg.Is<PlanType>(p => p == parameters.PlanType.Value))
.Returns(new EnterprisePlan(true));
var actual = await sut.Sut.CreateAsync(parameters);
Assert.IsType<BusinessUseAutomaticTaxStrategy>(actual);
}
public record EnterpriseAnnually : EnterprisePlan
{
public EnterpriseAnnually() : base(true)
{
}
}
}

View File

@ -3,10 +3,13 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Test.Billing.Stubs;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using Braintree; using Braintree;
@ -1167,7 +1170,9 @@ public class SubscriberServiceTests
{ {
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId) stripeAdapter.CustomerGetAsync(
provider.GatewayCustomerId,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")))
.Returns(new Customer .Returns(new Customer
{ {
Id = provider.GatewayCustomerId, Id = provider.GatewayCustomerId,
@ -1213,7 +1218,10 @@ public class SubscriberServiceTests
{ {
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId) stripeAdapter.CustomerGetAsync(
provider.GatewayCustomerId,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))
)
.Returns(new Customer .Returns(new Customer
{ {
Id = provider.GatewayCustomerId, Id = provider.GatewayCustomerId,
@ -1321,7 +1329,9 @@ public class SubscriberServiceTests
{ {
const string braintreeCustomerId = "braintree_customer_id"; const string braintreeCustomerId = "braintree_customer_id";
sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(provider.GatewayCustomerId) sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(
provider.GatewayCustomerId,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")))
.Returns(new Customer .Returns(new Customer
{ {
Id = provider.GatewayCustomerId, Id = provider.GatewayCustomerId,
@ -1373,7 +1383,9 @@ public class SubscriberServiceTests
{ {
const string braintreeCustomerId = "braintree_customer_id"; const string braintreeCustomerId = "braintree_customer_id";
sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(provider.GatewayCustomerId) sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(
provider.GatewayCustomerId,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")))
.Returns(new Customer .Returns(new Customer
{ {
Id = provider.GatewayCustomerId, Id = provider.GatewayCustomerId,
@ -1482,7 +1494,9 @@ public class SubscriberServiceTests
{ {
const string braintreeCustomerId = "braintree_customer_id"; const string braintreeCustomerId = "braintree_customer_id";
sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(provider.GatewayCustomerId) sutProvider.GetDependency<IStripeAdapter>().CustomerGetAsync(
provider.GatewayCustomerId,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")))
.Returns(new Customer .Returns(new Customer
{ {
Id = provider.GatewayCustomerId Id = provider.GatewayCustomerId
@ -1561,6 +1575,37 @@ public class SubscriberServiceTests
"Example Town", "Example Town",
"NY"); "NY");
sutProvider.GetDependency<IStripeAdapter>()
.CustomerUpdateAsync(
Arg.Is<string>(p => p == provider.GatewayCustomerId),
Arg.Is<CustomerUpdateOptions>(options =>
options.Address.Country == "US" &&
options.Address.PostalCode == "12345" &&
options.Address.Line1 == "123 Example St." &&
options.Address.Line2 == null &&
options.Address.City == "Example Town" &&
options.Address.State == "NY"))
.Returns(new Customer
{
Id = provider.GatewayCustomerId,
Address = new Address
{
Country = "US",
PostalCode = "12345",
Line1 = "123 Example St.",
Line2 = null,
City = "Example Town",
State = "NY"
},
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] }
});
var subscription = new Subscription { Items = new StripeList<SubscriptionItem>() };
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(Arg.Any<string>())
.Returns(subscription);
sutProvider.GetDependency<IAutomaticTaxFactory>().CreateAsync(Arg.Any<AutomaticTaxFactoryParameters>())
.Returns(new FakeAutomaticTaxStrategy(true));
await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation);
await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>( await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(

View File

@ -0,0 +1,35 @@
using Bit.Core.Billing.Services;
using Stripe;
namespace Bit.Core.Test.Billing.Stubs;
/// <param name="isAutomaticTaxEnabled">
/// Whether the subscription options will have automatic tax enabled or not.
/// </param>
public class FakeAutomaticTaxStrategy(
bool isAutomaticTaxEnabled) : IAutomaticTaxStrategy
{
public SubscriptionUpdateOptions? GetUpdateOptions(Subscription subscription)
{
return new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled }
};
}
public void SetCreateOptions(SubscriptionCreateOptions options, Customer customer)
{
options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled };
}
public void SetUpdateOptions(SubscriptionUpdateOptions options, Subscription subscription)
{
options.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled };
}
public void SetInvoiceCreatePreviewOptions(InvoiceCreatePreviewOptions options)
{
options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = isAutomaticTaxEnabled };
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,25 @@
-- Recreate the NotificationStatusView to include the Notification.TaskId column
CREATE OR ALTER VIEW [dbo].[NotificationStatusDetailsView]
AS
SELECT
N.[Id],
N.[Priority],
N.[Global],
N.[ClientType],
N.[UserId],
N.[OrganizationId],
N.[Title],
N.[Body],
N.[CreationDate],
N.[RevisionDate],
N.[TaskId],
NS.[UserId] AS [NotificationStatusUserId],
NS.[ReadDate],
NS.[DeletedDate]
FROM
[dbo].[Notification] AS N
LEFT JOIN
[dbo].[NotificationStatus] as NS
ON
N.[Id] = NS.[NotificationId]
GO