1
0
mirror of https://github.com/bitwarden/server.git synced 2025-05-25 13:24:50 -05:00

Merge branch 'main' into ac/pm-13274/unified-adding-a-group-to-a-collection-returns-500-error---but-works

This commit is contained in:
Thomas Rittson 2025-05-20 08:26:03 +10:00 committed by GitHub
commit cca7b43b0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
179 changed files with 13598 additions and 1507 deletions

3
.github/CODEOWNERS vendored
View File

@ -90,6 +90,9 @@ src/Admin/Views/Tools @bitwarden/team-billing-dev
.github/workflows/test-database.yml @bitwarden/team-platform-dev .github/workflows/test-database.yml @bitwarden/team-platform-dev
.github/workflows/test.yml @bitwarden/team-platform-dev .github/workflows/test.yml @bitwarden/team-platform-dev
**/*Platform* @bitwarden/team-platform-dev **/*Platform* @bitwarden/team-platform-dev
**/.dockerignore @bitwarden/team-platform-dev
**/Dockerfile @bitwarden/team-platform-dev
**/entrypoint.sh @bitwarden/team-platform-dev
# Multiple owners - DO NOT REMOVE (BRE) # Multiple owners - DO NOT REMOVE (BRE)
**/packages.lock.json **/packages.lock.json

View File

@ -2,7 +2,9 @@ name: Build on PR Target
on: on:
pull_request_target: pull_request_target:
types: [opened, synchronize] types: [opened, synchronize, reopened]
branches:
- "main"
defaults: defaults:
run: run:

View File

@ -7,8 +7,14 @@ on:
- "main" - "main"
- "rc" - "rc"
- "hotfix-rc" - "hotfix-rc"
pull_request:
types: [opened, synchronize, reopened]
branches-ignore:
- main
pull_request_target: pull_request_target:
types: [opened, synchronize] types: [opened, synchronize, reopened]
branches:
- "main"
jobs: jobs:
check-run: check-run:

View File

@ -3,7 +3,7 @@
<PropertyGroup> <PropertyGroup>
<TargetFramework>net8.0</TargetFramework> <TargetFramework>net8.0</TargetFramework>
<Version>2025.5.0</Version> <Version>2025.5.1</Version>
<RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace> <RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace>
<ImplicitUsings>enable</ImplicitUsings> <ImplicitUsings>enable</ImplicitUsings>

View File

@ -8,12 +8,10 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Microsoft.Extensions.DependencyInjection;
using Stripe; using Stripe;
namespace Bit.Commercial.Core.AdminConsole.Providers; namespace Bit.Commercial.Core.AdminConsole.Providers;
@ -23,7 +21,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
private readonly IEventService _eventService; private readonly IEventService _eventService;
private readonly IMailService _mailService; private readonly IMailService _mailService;
private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationService _organizationService;
private readonly IProviderOrganizationRepository _providerOrganizationRepository; private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IStripeAdapter _stripeAdapter; private readonly IStripeAdapter _stripeAdapter;
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
@ -31,26 +28,22 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
private readonly ISubscriberService _subscriberService; private readonly ISubscriberService _subscriberService;
private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly IAutomaticTaxStrategy _automaticTaxStrategy;
public RemoveOrganizationFromProviderCommand( public RemoveOrganizationFromProviderCommand(
IEventService eventService, IEventService eventService,
IMailService mailService, IMailService mailService,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
IOrganizationService organizationService,
IProviderOrganizationRepository providerOrganizationRepository, IProviderOrganizationRepository providerOrganizationRepository,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
IFeatureService featureService, IFeatureService featureService,
IProviderBillingService providerBillingService, IProviderBillingService providerBillingService,
ISubscriberService subscriberService, ISubscriberService subscriberService,
IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery,
IPricingClient pricingClient, IPricingClient pricingClient)
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
{ {
_eventService = eventService; _eventService = eventService;
_mailService = mailService; _mailService = mailService;
_organizationRepository = organizationRepository; _organizationRepository = organizationRepository;
_organizationService = organizationService;
_providerOrganizationRepository = providerOrganizationRepository; _providerOrganizationRepository = providerOrganizationRepository;
_stripeAdapter = stripeAdapter; _stripeAdapter = stripeAdapter;
_featureService = featureService; _featureService = featureService;
@ -58,7 +51,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
_subscriberService = subscriberService; _subscriberService = subscriberService;
_hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery;
_pricingClient = pricingClient; _pricingClient = pricingClient;
_automaticTaxStrategy = automaticTaxStrategy;
} }
public async Task RemoveOrganizationFromProvider( public async Task RemoveOrganizationFromProvider(
@ -76,7 +68,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync( if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync(
providerOrganization.OrganizationId, providerOrganization.OrganizationId,
Array.Empty<Guid>(), [],
includeProvider: false)) includeProvider: false))
{ {
throw new BadRequestException("Organization must have at least one confirmed owner."); throw new BadRequestException("Organization must have at least one confirmed owner.");
@ -101,7 +93,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
/// <summary> /// <summary>
/// When a client organization is unlinked from a provider, we have to check if they're Stripe-enabled /// When a client organization is unlinked from a provider, we have to check if they're Stripe-enabled
/// and, if they are, we remove their MSP discount and set their Subscription to `send_invoice`. This is because /// and, if they are, we remove their MSP discount and set their Subscription to `send_invoice`. This is because
/// the provider's payment method will be removed from their Stripe customer causing ensuing charges to fail. Lastly, /// the provider's payment method will be removed from their Stripe customer, causing ensuing charges to fail. Lastly,
/// we email the organization owners letting them know they need to add a new payment method. /// we email the organization owners letting them know they need to add a new payment method.
/// </summary> /// </summary>
private async Task ResetOrganizationBillingAsync( private async Task ResetOrganizationBillingAsync(
@ -141,15 +133,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }] Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }]
}; };
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var setNonUSBusinessUseToReverseCharge = _featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{ {
_automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
} }
else else if (customer.HasRecognizedTaxLocation())
{ {
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{ {
Enabled = true Enabled = customer.Address.Country == "US" ||
customer.TaxIds.Any()
}; };
} }
@ -186,7 +181,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
await _mailService.SendProviderUpdatePaymentMethod( await _mailService.SendProviderUpdatePaymentMethod(
organization.Id, organization.Id,
organization.Name, organization.Name,
provider.Name, provider.Name!,
organizationOwnerEmails); organizationOwnerEmails);
} }
} }

View File

@ -67,6 +67,7 @@ public class BusinessUnitConverter(
organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb; organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb;
organization.UsePolicies = updatedPlan.HasPolicies; organization.UsePolicies = updatedPlan.HasPolicies;
organization.UseSso = updatedPlan.HasSso; organization.UseSso = updatedPlan.HasSso;
organization.UseOrganizationDomains = updatedPlan.HasOrganizationDomains;
organization.UseGroups = updatedPlan.HasGroups; organization.UseGroups = updatedPlan.HasGroups;
organization.UseEvents = updatedPlan.HasEvents; organization.UseEvents = updatedPlan.HasEvents;
organization.UseDirectory = updatedPlan.HasDirectory; organization.UseDirectory = updatedPlan.HasDirectory;

View File

@ -16,7 +16,8 @@ using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
@ -25,7 +26,6 @@ using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Braintree; using Braintree;
using CsvHelper; using CsvHelper;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Stripe; using Stripe;
@ -50,8 +50,7 @@ public class ProviderBillingService(
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService,
ITaxService taxService, ITaxService taxService)
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
: IProviderBillingService : IProviderBillingService
{ {
public async Task AddExistingOrganization( public async Task AddExistingOrganization(
@ -97,6 +96,7 @@ public class ProviderBillingService(
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.UsePolicies = plan.HasPolicies; organization.UsePolicies = plan.HasPolicies;
organization.UseSso = plan.HasSso; organization.UseSso = plan.HasSso;
organization.UseOrganizationDomains = plan.HasOrganizationDomains;
organization.UseGroups = plan.HasGroups; organization.UseGroups = plan.HasGroups;
organization.UseEvents = plan.HasEvents; organization.UseEvents = plan.HasEvents;
organization.UseDirectory = plan.HasDirectory; organization.UseDirectory = plan.HasDirectory;
@ -125,7 +125,7 @@ public class ProviderBillingService(
/* /*
* We have to scale the provider's seats before the ProviderOrganization * We have to scale the provider's seats before the ProviderOrganization
* row is inserted so the added organization's seats don't get double counted. * row is inserted so the added organization's seats don't get double-counted.
*/ */
await ScaleSeats(provider, organization.PlanType, organization.Seats!.Value); await ScaleSeats(provider, organization.PlanType, organization.Seats!.Value);
@ -233,7 +233,7 @@ public class ProviderBillingService(
var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions
{ {
Expand = ["tax_ids"] Expand = ["tax", "tax_ids"]
}); });
var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); var providerTaxId = providerCustomer.TaxIds.FirstOrDefault();
@ -281,6 +281,13 @@ public class ProviderBillingService(
] ]
}; };
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && providerCustomer.Address is not { Country: "US" })
{
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
organization.GatewayCustomerId = customer.Id; organization.GatewayCustomerId = customer.Id;
@ -517,6 +524,13 @@ public class ProviderBillingService(
} }
}; };
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && taxInfo.BillingAddressCountry != "US")
{
options.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber)) if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber))
{ {
var taxIdType = taxService.GetStripeTaxCode( var taxIdType = taxService.GetStripeTaxCode(
@ -528,6 +542,7 @@ public class ProviderBillingService(
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInfo.BillingAddressCountry, taxInfo.BillingAddressCountry,
taxInfo.TaxIdNumber); taxInfo.TaxIdNumber);
throw new BadRequestException("billingTaxIdTypeInferenceError"); throw new BadRequestException("billingTaxIdTypeInferenceError");
} }
@ -692,6 +707,13 @@ public class ProviderBillingService(
customer.Metadata.ContainsKey(BraintreeCustomerIdKey) || customer.Metadata.ContainsKey(BraintreeCustomerIdKey) ||
setupIntent.IsUnverifiedBankAccount()); setupIntent.IsUnverifiedBankAccount());
int? trialPeriodDays = provider.Type switch
{
ProviderType.Msp when usePaymentMethod => 14,
ProviderType.BusinessUnit when usePaymentMethod => 4,
_ => null
};
var subscriptionCreateOptions = new SubscriptionCreateOptions var subscriptionCreateOptions = new SubscriptionCreateOptions
{ {
CollectionMethod = usePaymentMethod ? CollectionMethod = usePaymentMethod ?
@ -705,17 +727,24 @@ public class ProviderBillingService(
}, },
OffSession = true, OffSession = true,
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
TrialPeriodDays = usePaymentMethod ? 14 : null TrialPeriodDays = trialPeriodDays
}; };
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var setNonUSBusinessUseToReverseCharge =
{ featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
} if (setNonUSBusinessUseToReverseCharge)
else
{ {
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
} }
else if (customer.HasRecognizedTaxLocation())
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customer.Address.Country == "US" ||
customer.TaxIds.Any()
};
}
try try
{ {

View File

@ -1,13 +1,9 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Licenses.Extensions;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.SecretsManager.Queries.Projects.Interfaces; using Bit.Core.SecretsManager.Queries.Projects.Interfaces;
using Bit.Core.SecretsManager.Repositories; using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
namespace Bit.Commercial.Core.SecretsManager.Queries.Projects; namespace Bit.Commercial.Core.SecretsManager.Queries.Projects;
@ -17,72 +13,42 @@ public class MaxProjectsQuery : IMaxProjectsQuery
private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationRepository _organizationRepository;
private readonly IProjectRepository _projectRepository; private readonly IProjectRepository _projectRepository;
private readonly IGlobalSettings _globalSettings; private readonly IGlobalSettings _globalSettings;
private readonly ILicensingService _licensingService;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
public MaxProjectsQuery( public MaxProjectsQuery(
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
IProjectRepository projectRepository, IProjectRepository projectRepository,
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
ILicensingService licensingService,
IPricingClient pricingClient) IPricingClient pricingClient)
{ {
_organizationRepository = organizationRepository; _organizationRepository = organizationRepository;
_projectRepository = projectRepository; _projectRepository = projectRepository;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_licensingService = licensingService;
_pricingClient = pricingClient; _pricingClient = pricingClient;
} }
public async Task<(short? max, bool? overMax)> GetByOrgIdAsync(Guid organizationId, int projectsToAdd) public async Task<(short? max, bool? overMax)> GetByOrgIdAsync(Guid organizationId, int projectsToAdd)
{ {
// "MaxProjects" only applies to free 2-person organizations, which can't be self-hosted.
if (_globalSettings.SelfHosted)
{
return (null, null);
}
var org = await _organizationRepository.GetByIdAsync(organizationId); var org = await _organizationRepository.GetByIdAsync(organizationId);
if (org == null) if (org == null)
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
var (planType, maxProjects) = await GetPlanTypeAndMaxProjectsAsync(org); var plan = await _pricingClient.GetPlan(org.PlanType);
if (planType != PlanType.Free) if (plan is not { SecretsManager: not null, Type: PlanType.Free })
{ {
return (null, null); return (null, null);
} }
var projects = await _projectRepository.GetProjectCountByOrganizationIdAsync(organizationId); var projects = await _projectRepository.GetProjectCountByOrganizationIdAsync(organizationId);
return ((short? max, bool? overMax))(projects + projectsToAdd > maxProjects ? (maxProjects, true) : (maxProjects, false)); return ((short? max, bool? overMax))(projects + projectsToAdd > plan.SecretsManager.MaxProjects ? (plan.SecretsManager.MaxProjects, true) : (plan.SecretsManager.MaxProjects, false));
}
private async Task<(PlanType planType, int maxProjects)> GetPlanTypeAndMaxProjectsAsync(Organization organization)
{
if (_globalSettings.SelfHosted)
{
var license = await _licensingService.ReadOrganizationLicenseAsync(organization);
if (license == null)
{
throw new BadRequestException("License not found.");
}
var claimsPrincipal = _licensingService.GetClaimsPrincipalFromLicense(license);
var maxProjects = claimsPrincipal.GetValue<int?>(OrganizationLicenseConstants.SmMaxProjects);
if (!maxProjects.HasValue)
{
throw new BadRequestException("License does not contain a value for max Secrets Manager projects");
}
var planType = claimsPrincipal.GetValue<PlanType>(OrganizationLicenseConstants.PlanType);
return (planType, maxProjects.Value);
}
var plan = await _pricingClient.GetPlan(organization.PlanType);
if (plan is { SupportsSecretsManager: true })
{
return (plan.Type, plan.SecretsManager.MaxProjects);
}
throw new BadRequestException("Existing plan not found.");
} }
} }

View File

@ -15,7 +15,7 @@
}, },
"devDependencies": { "devDependencies": {
"css-loader": "7.1.2", "css-loader": "7.1.2",
"expose-loader": "5.0.0", "expose-loader": "5.0.1",
"mini-css-extract-plugin": "2.9.2", "mini-css-extract-plugin": "2.9.2",
"sass": "1.85.0", "sass": "1.85.0",
"sass-loader": "16.0.4", "sass-loader": "16.0.4",
@ -1083,9 +1083,9 @@
} }
}, },
"node_modules/expose-loader": { "node_modules/expose-loader": {
"version": "5.0.0", "version": "5.0.1",
"resolved": "https://registry.npmjs.org/expose-loader/-/expose-loader-5.0.0.tgz", "resolved": "https://registry.npmjs.org/expose-loader/-/expose-loader-5.0.1.tgz",
"integrity": "sha512-BtUqYRmvx1bEY5HN6eK2I9URUZgNmN0x5UANuocaNjXSgfoDlkXt+wyEMe7i5DzDNh2BKJHPc5F4rBwEdSQX6w==", "integrity": "sha512-5YPZuszN/eWND/B+xuq5nIpb/l5TV1HYmdO6SubYtHv+HenVw9/6bn33Mm5reY8DNid7AVtbARvyUD34edfCtg==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"engines": { "engines": {

View File

@ -14,7 +14,7 @@
}, },
"devDependencies": { "devDependencies": {
"css-loader": "7.1.2", "css-loader": "7.1.2",
"expose-loader": "5.0.0", "expose-loader": "5.0.1",
"mini-css-extract-plugin": "2.9.2", "mini-css-extract-plugin": "2.9.2",
"sass": "1.85.0", "sass": "1.85.0",
"sass-loader": "16.0.4", "sass-loader": "16.0.4",

View File

@ -1,4 +1,5 @@
using Bit.Commercial.Core.AdminConsole.Providers; using Bit.Commercial.Core.AdminConsole.Providers;
using Bit.Core;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
@ -223,31 +224,115 @@ public class RemoveOrganizationFromProviderCommandTests
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
options.Description == string.Empty &&
options.Email == organization.BillingEmail &&
options.Expand[0] == "tax" &&
options.Expand[1] == "tax_ids")).Returns(new Customer
{
Id = "customer_id",
Address = new Address
{
Country = "US"
}
});
stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription
{ {
Id = "subscription_id" Id = "subscription_id"
}); });
sutProvider.GetDependency<IAutomaticTaxStrategy>() await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options => await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == organization.GatewayCustomerId && options.Customer == organization.GatewayCustomerId &&
options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice &&
options.DaysUntilDue == 30 && options.DaysUntilDue == 30 &&
options.Metadata["organizationId"] == organization.Id.ToString() && options.AutomaticTax.Enabled == true &&
options.OffSession == true && options.Metadata["organizationId"] == organization.Id.ToString() &&
options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && options.OffSession == true &&
options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
options.Items.First().Quantity == organization.Seats) options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId &&
, Arg.Any<Customer>())) options.Items.First().Quantity == organization.Seats));
.Do(x =>
await sutProvider.GetDependency<IProviderBillingService>().Received(1)
.ScaleSeats(provider, organization.PlanType, -organization.Seats ?? 0);
await organizationRepository.Received(1).ReplaceAsync(Arg.Is<Organization>(
org =>
org.BillingEmail == "a@example.com" &&
org.GatewaySubscriptionId == "subscription_id" &&
org.Status == OrganizationStatusType.Created));
await sutProvider.GetDependency<IProviderOrganizationRepository>().Received(1)
.DeleteAsync(providerOrganization);
await sutProvider.GetDependency<IEventService>().Received(1)
.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed);
await sutProvider.GetDependency<IMailService>().Received(1)
.SendProviderUpdatePaymentMethod(
organization.Id,
organization.Name,
provider.Name,
Arg.Is<IEnumerable<string>>(emails => emails.FirstOrDefault() == "a@example.com"));
}
[Theory, BitAutoData]
public async Task RemoveOrganizationFromProvider_OrganizationStripeEnabled_ConsolidatedBilling_ReverseCharge_MakesCorrectInvocations(
Provider provider,
ProviderOrganization providerOrganization,
Organization organization,
SutProvider<RemoveOrganizationFromProviderCommand> sutProvider)
{
provider.Status = ProviderStatusType.Billable;
providerOrganization.ProviderId = provider.Id;
organization.Status = OrganizationStatusType.Managed;
organization.PlanType = PlanType.TeamsMonthly;
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan);
sutProvider.GetDependency<IHasConfirmedOwnersExceptQuery>().HasConfirmedOwnersExceptAsync(
providerOrganization.OrganizationId,
[],
includeProvider: false)
.Returns(true);
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns([
"a@example.com",
"b@example.com"
]);
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
options.Description == string.Empty &&
options.Email == organization.BillingEmail &&
options.Expand[0] == "tax" &&
options.Expand[1] == "tax_ids")).Returns(new Customer
{ {
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions Id = "customer_id",
Address = new Address
{ {
Enabled = true Country = "US"
}; }
}); });
stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription
{
Id = "subscription_id"
});
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options => await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>

View File

@ -17,6 +17,7 @@ using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
@ -261,7 +262,7 @@ public class ProviderBillingServiceTests
}; };
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>( sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>(
options => options.Expand.FirstOrDefault() == "tax_ids")) options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids")))
.Returns(providerCustomer); .Returns(providerCustomer);
sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri
@ -311,6 +312,91 @@ public class ProviderBillingServiceTests
org => org.GatewayCustomerId == "customer_id")); org => org.GatewayCustomerId == "customer_id"));
} }
[Theory, BitAutoData]
public async Task CreateCustomer_ForClientOrg_ReverseCharge_Succeeds(
Provider provider,
Organization organization,
SutProvider<ProviderBillingService> sutProvider)
{
organization.GatewayCustomerId = null;
organization.Name = "Name";
organization.BusinessName = "BusinessName";
var providerCustomer = new Customer
{
Address = new Address
{
Country = "CA",
PostalCode = "12345",
Line1 = "123 Main St.",
Line2 = "Unit 4",
City = "Fake Town",
State = "Fake State"
},
TaxIds = new StripeList<TaxId>
{
Data =
[
new TaxId { Type = "TYPE", Value = "VALUE" }
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>(
options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids")))
.Returns(providerCustomer);
sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri
.Returns(new Bit.Core.Settings.GlobalSettings.BaseServiceUriSettings(new Bit.Core.Settings.GlobalSettings())
{
CloudRegion = "US"
});
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
sutProvider.GetDependency<IStripeAdapter>().CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
options =>
options.Address.Country == providerCustomer.Address.Country &&
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
options.Address.Line1 == providerCustomer.Address.Line1 &&
options.Address.Line2 == providerCustomer.Address.Line2 &&
options.Address.City == providerCustomer.Address.City &&
options.Address.State == providerCustomer.Address.State &&
options.Name == organization.DisplayName() &&
options.Description == $"{provider.Name} Client Organization" &&
options.Email == provider.BillingEmail &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
options.Metadata["region"] == "US" &&
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value &&
options.TaxExempt == StripeConstants.TaxExempt.Reverse))
.Returns(new Customer { Id = "customer_id" });
await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization);
await sutProvider.GetDependency<IStripeAdapter>().Received(1).CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
options =>
options.Address.Country == providerCustomer.Address.Country &&
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
options.Address.Line1 == providerCustomer.Address.Line1 &&
options.Address.Line2 == providerCustomer.Address.Line2 &&
options.Address.City == providerCustomer.Address.City &&
options.Address.State == providerCustomer.Address.State &&
options.Name == organization.DisplayName() &&
options.Description == $"{provider.Name} Client Organization" &&
options.Email == provider.BillingEmail &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
options.Metadata["region"] == "US" &&
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value));
await sutProvider.GetDependency<IOrganizationRepository>().Received(1).ReplaceAsync(Arg.Is<Organization>(
org => org.GatewayCustomerId == "customer_id"));
}
#endregion #endregion
#region GenerateClientInvoiceReport #region GenerateClientInvoiceReport
@ -1181,6 +1267,62 @@ public class ProviderBillingServiceTests
Assert.Equivalent(expected, actual); Assert.Equivalent(expected, actual);
} }
[Theory, BitAutoData]
public async Task SetupCustomer_WithCard_ReverseCharge_Success(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var expected = new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
};
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token");
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Email == provider.BillingEmail &&
o.PaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber &&
o.TaxExempt == StripeConstants.TaxExempt.Reverse))
.Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
Assert.Equivalent(expected, actual);
}
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid( public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
@ -1306,7 +1448,7 @@ public class ProviderBillingServiceTests
.Returns(new Customer .Returns(new Customer
{ {
Id = "customer_id", Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Address = new Address { Country = "US" }
}); });
var providerPlans = new List<ProviderPlan> var providerPlans = new List<ProviderPlan>
@ -1358,7 +1500,7 @@ public class ProviderBillingServiceTests
var customer = new Customer var customer = new Customer
{ {
Id = "customer_id", Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Address = new Address { Country = "US" }
}; };
sutProvider.GetDependency<ISubscriberService>() sutProvider.GetDependency<ISubscriberService>()
.GetCustomerOrThrow( .GetCustomerOrThrow(
@ -1398,19 +1540,6 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>( sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub => sub =>
sub.AutomaticTax.Enabled == true && sub.AutomaticTax.Enabled == true &&
@ -1442,11 +1571,11 @@ public class ProviderBillingServiceTests
var customer = new Customer var customer = new Customer
{ {
Id = "customer_id", Id = "customer_id",
Address = new Address { Country = "US" },
InvoiceSettings = new CustomerInvoiceSettings InvoiceSettings = new CustomerInvoiceSettings
{ {
DefaultPaymentMethodId = "pm_123" DefaultPaymentMethodId = "pm_123"
}, }
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
sutProvider.GetDependency<ISubscriberService>() sutProvider.GetDependency<ISubscriberService>()
@ -1487,19 +1616,6 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IFeatureService>() sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
@ -1535,9 +1651,9 @@ public class ProviderBillingServiceTests
var customer = new Customer var customer = new Customer
{ {
Id = "customer_id", Id = "customer_id",
Address = new Address { Country = "US" },
InvoiceSettings = new CustomerInvoiceSettings(), InvoiceSettings = new CustomerInvoiceSettings(),
Metadata = new Dictionary<string, string>(), Metadata = new Dictionary<string, string>()
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
sutProvider.GetDependency<ISubscriberService>() sutProvider.GetDependency<ISubscriberService>()
@ -1578,19 +1694,6 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IFeatureService>() sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
@ -1645,12 +1748,15 @@ public class ProviderBillingServiceTests
var customer = new Customer var customer = new Customer
{ {
Id = "customer_id", Id = "customer_id",
Address = new Address
{
Country = "US"
},
InvoiceSettings = new CustomerInvoiceSettings(), InvoiceSettings = new CustomerInvoiceSettings(),
Metadata = new Dictionary<string, string> Metadata = new Dictionary<string, string>
{ {
["btCustomerId"] = "braintree_customer_id" ["btCustomerId"] = "braintree_customer_id"
}, }
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
sutProvider.GetDependency<ISubscriberService>() sutProvider.GetDependency<ISubscriberService>()
@ -1691,22 +1797,92 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>() sutProvider.GetDependency<IFeatureService>()
.When(x => x.SetCreateOptions( .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id") sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
, Arg.Is<Customer>(p => p == customer))) sub =>
.Do(x => sub.AutomaticTax.Enabled == true &&
sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically &&
sub.Customer == "customer_id" &&
sub.DaysUntilDue == null &&
sub.Items.Count == 2 &&
sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams &&
sub.Items.ElementAt(0).Quantity == 100 &&
sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise &&
sub.Items.ElementAt(1).Quantity == 100 &&
sub.Metadata["providerId"] == provider.Id.ToString() &&
sub.OffSession == true &&
sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
sub.TrialPeriodDays == 14)).Returns(expected);
var actual = await sutProvider.Sut.SetupSubscription(provider);
Assert.Equivalent(expected, actual);
}
[Theory, BitAutoData]
public async Task SetupSubscription_ReverseCharge_Succeeds(
SutProvider<ProviderBillingService> sutProvider,
Provider provider)
{
provider.Type = ProviderType.Msp;
provider.GatewaySubscriptionId = null;
var customer = new Customer
{
Id = "customer_id",
Address = new Address { Country = "CA" },
InvoiceSettings = new CustomerInvoiceSettings
{ {
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions DefaultPaymentMethodId = "pm_123"
{ }
Enabled = true };
};
}); sutProvider.GetDependency<ISubscriberService>()
.GetCustomerOrThrow(
provider,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer);
var providerPlans = new List<ProviderPlan>
{
new()
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.TeamsMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
AllocatedSeats = 0
},
new()
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.EnterpriseMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
AllocatedSeats = 0
}
};
foreach (var plan in providerPlans)
{
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(plan.PlanType)
.Returns(StaticStore.GetPlan(plan.PlanType));
}
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id)
.Returns(providerPlans);
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IFeatureService>() sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>( sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub => sub =>
sub.AutomaticTax.Enabled == true && sub.AutomaticTax.Enabled == true &&

View File

@ -1,4 +1,4 @@
using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Services.Implementations;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using Xunit; using Xunit;

View File

@ -1,14 +1,10 @@
using System.Security.Claims; using Bit.Commercial.Core.SecretsManager.Queries.Projects;
using Bit.Commercial.Core.SecretsManager.Queries.Projects;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Licenses;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.SecretsManager.Repositories; using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
@ -22,11 +18,26 @@ namespace Bit.Commercial.Core.Test.SecretsManager.Queries.Projects;
[SutProviderCustomize] [SutProviderCustomize]
public class MaxProjectsQueryTests public class MaxProjectsQueryTests
{ {
[Theory]
[BitAutoData]
public async Task GetByOrgIdAsync_SelfHosted_ReturnsNulls(SutProvider<MaxProjectsQuery> sutProvider,
Guid organizationId)
{
sutProvider.GetDependency<IGlobalSettings>().SelfHosted.Returns(true);
var (max, overMax) = await sutProvider.Sut.GetByOrgIdAsync(organizationId, 1);
Assert.Null(max);
Assert.Null(overMax);
}
[Theory] [Theory]
[BitAutoData] [BitAutoData]
public async Task GetByOrgIdAsync_OrganizationIsNull_ThrowsNotFound(SutProvider<MaxProjectsQuery> sutProvider, public async Task GetByOrgIdAsync_OrganizationIsNull_ThrowsNotFound(SutProvider<MaxProjectsQuery> sutProvider,
Guid organizationId) Guid organizationId)
{ {
sutProvider.GetDependency<IGlobalSettings>().SelfHosted.Returns(false);
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(default).ReturnsNull(); sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(default).ReturnsNull();
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.GetByOrgIdAsync(organizationId, 1)); await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.GetByOrgIdAsync(organizationId, 1));
@ -35,54 +46,6 @@ public class MaxProjectsQueryTests
.GetProjectCountByOrganizationIdAsync(organizationId); .GetProjectCountByOrganizationIdAsync(organizationId);
} }
[Theory]
[BitAutoData(PlanType.FamiliesAnnually2019)]
[BitAutoData(PlanType.Custom)]
[BitAutoData(PlanType.FamiliesAnnually)]
public async Task GetByOrgIdAsync_Cloud_SmPlanIsNull_ThrowsBadRequest(PlanType planType,
SutProvider<MaxProjectsQuery> sutProvider, Organization organization)
{
organization.PlanType = planType;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
sutProvider.GetDependency<IGlobalSettings>().SelfHosted.Returns(false);
var plan = StaticStore.GetPlan(planType);
sutProvider.GetDependency<IPricingClient>().GetPlan(organization.PlanType).Returns(plan);
await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1));
await sutProvider.GetDependency<IProjectRepository>()
.DidNotReceiveWithAnyArgs()
.GetProjectCountByOrganizationIdAsync(organization.Id);
}
[Theory]
[BitAutoData]
public async Task GetByOrgIdAsync_SelfHosted_NoMaxProjectsClaim_ThrowsBadRequest(
SutProvider<MaxProjectsQuery> sutProvider, Organization organization)
{
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
sutProvider.GetDependency<IGlobalSettings>().SelfHosted.Returns(true);
var license = new OrganizationLicense();
var claimsPrincipal = new ClaimsPrincipal();
sutProvider.GetDependency<ILicensingService>().ReadOrganizationLicenseAsync(organization).Returns(license);
sutProvider.GetDependency<ILicensingService>().GetClaimsPrincipalFromLicense(license).Returns(claimsPrincipal);
await Assert.ThrowsAsync<BadRequestException>(
async () => await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1));
await sutProvider.GetDependency<IProjectRepository>()
.DidNotReceiveWithAnyArgs()
.GetProjectCountByOrganizationIdAsync(organization.Id);
}
[Theory] [Theory]
[BitAutoData(PlanType.TeamsMonthly2019)] [BitAutoData(PlanType.TeamsMonthly2019)]
[BitAutoData(PlanType.TeamsMonthly2020)] [BitAutoData(PlanType.TeamsMonthly2020)]
@ -97,57 +60,16 @@ public class MaxProjectsQueryTests
[BitAutoData(PlanType.EnterpriseAnnually2019)] [BitAutoData(PlanType.EnterpriseAnnually2019)]
[BitAutoData(PlanType.EnterpriseAnnually2020)] [BitAutoData(PlanType.EnterpriseAnnually2020)]
[BitAutoData(PlanType.EnterpriseAnnually)] [BitAutoData(PlanType.EnterpriseAnnually)]
public async Task GetByOrgIdAsync_Cloud_SmNoneFreePlans_ReturnsNull(PlanType planType, public async Task GetByOrgIdAsync_SmNoneFreePlans_ReturnsNull(PlanType planType,
SutProvider<MaxProjectsQuery> sutProvider, Organization organization) SutProvider<MaxProjectsQuery> sutProvider, Organization organization)
{ {
organization.PlanType = planType;
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
sutProvider.GetDependency<IGlobalSettings>().SelfHosted.Returns(false); sutProvider.GetDependency<IGlobalSettings>().SelfHosted.Returns(false);
var plan = StaticStore.GetPlan(planType);
sutProvider.GetDependency<IPricingClient>().GetPlan(organization.PlanType).Returns(plan);
var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1);
Assert.Null(limit);
Assert.Null(overLimit);
await sutProvider.GetDependency<IProjectRepository>().DidNotReceiveWithAnyArgs()
.GetProjectCountByOrganizationIdAsync(organization.Id);
}
[Theory]
[BitAutoData(PlanType.TeamsMonthly2019)]
[BitAutoData(PlanType.TeamsMonthly2020)]
[BitAutoData(PlanType.TeamsMonthly)]
[BitAutoData(PlanType.TeamsAnnually2019)]
[BitAutoData(PlanType.TeamsAnnually2020)]
[BitAutoData(PlanType.TeamsAnnually)]
[BitAutoData(PlanType.TeamsStarter)]
[BitAutoData(PlanType.EnterpriseMonthly2019)]
[BitAutoData(PlanType.EnterpriseMonthly2020)]
[BitAutoData(PlanType.EnterpriseMonthly)]
[BitAutoData(PlanType.EnterpriseAnnually2019)]
[BitAutoData(PlanType.EnterpriseAnnually2020)]
[BitAutoData(PlanType.EnterpriseAnnually)]
public async Task GetByOrgIdAsync_SelfHosted_SmNoneFreePlans_ReturnsNull(PlanType planType,
SutProvider<MaxProjectsQuery> sutProvider, Organization organization)
{
organization.PlanType = planType; organization.PlanType = planType;
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
sutProvider.GetDependency<IGlobalSettings>().SelfHosted.Returns(true);
var license = new OrganizationLicense(); sutProvider.GetDependency<IPricingClient>().GetPlan(organization.PlanType)
var plan = StaticStore.GetPlan(planType); .Returns(StaticStore.GetPlan(organization.PlanType));
var claims = new List<Claim>
{
new (nameof(OrganizationLicenseConstants.PlanType), organization.PlanType.ToString()),
new (nameof(OrganizationLicenseConstants.SmMaxProjects), plan.SecretsManager.MaxProjects.ToString())
};
var identity = new ClaimsIdentity(claims, "TestAuthenticationType");
var claimsPrincipal = new ClaimsPrincipal(identity);
sutProvider.GetDependency<ILicensingService>().ReadOrganizationLicenseAsync(organization).Returns(license);
sutProvider.GetDependency<ILicensingService>().GetClaimsPrincipalFromLicense(license).Returns(claimsPrincipal);
var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1); var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1);
@ -183,7 +105,7 @@ public class MaxProjectsQueryTests
[BitAutoData(PlanType.Free, 3, 4, true)] [BitAutoData(PlanType.Free, 3, 4, true)]
[BitAutoData(PlanType.Free, 4, 4, true)] [BitAutoData(PlanType.Free, 4, 4, true)]
[BitAutoData(PlanType.Free, 40, 4, true)] [BitAutoData(PlanType.Free, 40, 4, true)]
public async Task GetByOrgIdAsync_Cloud_SmFreePlan__Success(PlanType planType, int projects, int projectsToAdd, bool expectedOverMax, public async Task GetByOrgIdAsync_SmFreePlan__Success(PlanType planType, int projects, int projectsToAdd, bool expectedOverMax,
SutProvider<MaxProjectsQuery> sutProvider, Organization organization) SutProvider<MaxProjectsQuery> sutProvider, Organization organization)
{ {
organization.PlanType = planType; organization.PlanType = planType;
@ -191,66 +113,8 @@ public class MaxProjectsQueryTests
sutProvider.GetDependency<IProjectRepository>().GetProjectCountByOrganizationIdAsync(organization.Id) sutProvider.GetDependency<IProjectRepository>().GetProjectCountByOrganizationIdAsync(organization.Id)
.Returns(projects); .Returns(projects);
sutProvider.GetDependency<IGlobalSettings>().SelfHosted.Returns(false); sutProvider.GetDependency<IPricingClient>().GetPlan(organization.PlanType)
var plan = StaticStore.GetPlan(planType); .Returns(StaticStore.GetPlan(organization.PlanType));
sutProvider.GetDependency<IPricingClient>().GetPlan(organization.PlanType).Returns(plan);
var (max, overMax) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, projectsToAdd);
Assert.NotNull(max);
Assert.NotNull(overMax);
Assert.Equal(3, max.Value);
Assert.Equal(expectedOverMax, overMax);
await sutProvider.GetDependency<IProjectRepository>().Received(1)
.GetProjectCountByOrganizationIdAsync(organization.Id);
}
[Theory]
[BitAutoData(PlanType.Free, 0, 1, false)]
[BitAutoData(PlanType.Free, 1, 1, false)]
[BitAutoData(PlanType.Free, 2, 1, false)]
[BitAutoData(PlanType.Free, 3, 1, true)]
[BitAutoData(PlanType.Free, 4, 1, true)]
[BitAutoData(PlanType.Free, 40, 1, true)]
[BitAutoData(PlanType.Free, 0, 2, false)]
[BitAutoData(PlanType.Free, 1, 2, false)]
[BitAutoData(PlanType.Free, 2, 2, true)]
[BitAutoData(PlanType.Free, 3, 2, true)]
[BitAutoData(PlanType.Free, 4, 2, true)]
[BitAutoData(PlanType.Free, 40, 2, true)]
[BitAutoData(PlanType.Free, 0, 3, false)]
[BitAutoData(PlanType.Free, 1, 3, true)]
[BitAutoData(PlanType.Free, 2, 3, true)]
[BitAutoData(PlanType.Free, 3, 3, true)]
[BitAutoData(PlanType.Free, 4, 3, true)]
[BitAutoData(PlanType.Free, 40, 3, true)]
[BitAutoData(PlanType.Free, 0, 4, true)]
[BitAutoData(PlanType.Free, 1, 4, true)]
[BitAutoData(PlanType.Free, 2, 4, true)]
[BitAutoData(PlanType.Free, 3, 4, true)]
[BitAutoData(PlanType.Free, 4, 4, true)]
[BitAutoData(PlanType.Free, 40, 4, true)]
public async Task GetByOrgIdAsync_SelfHosted_SmFreePlan__Success(PlanType planType, int projects, int projectsToAdd, bool expectedOverMax,
SutProvider<MaxProjectsQuery> sutProvider, Organization organization)
{
organization.PlanType = planType;
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
sutProvider.GetDependency<IProjectRepository>().GetProjectCountByOrganizationIdAsync(organization.Id)
.Returns(projects);
sutProvider.GetDependency<IGlobalSettings>().SelfHosted.Returns(true);
var license = new OrganizationLicense();
var plan = StaticStore.GetPlan(planType);
var claims = new List<Claim>
{
new (nameof(OrganizationLicenseConstants.PlanType), organization.PlanType.ToString()),
new (nameof(OrganizationLicenseConstants.SmMaxProjects), plan.SecretsManager.MaxProjects.ToString())
};
var identity = new ClaimsIdentity(claims, "TestAuthenticationType");
var claimsPrincipal = new ClaimsPrincipal(identity);
sutProvider.GetDependency<ILicensingService>().ReadOrganizationLicenseAsync(organization).Returns(license);
sutProvider.GetDependency<ILicensingService>().GetClaimsPrincipalFromLicense(license).Returns(claimsPrincipal);
var (max, overMax) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, projectsToAdd); var (max, overMax) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, projectsToAdd);

View File

@ -11,6 +11,7 @@ MAILCATCHER_PORT=1080
# Alternative databases # Alternative databases
POSTGRES_PASSWORD=SET_A_PASSWORD_HERE_123 POSTGRES_PASSWORD=SET_A_PASSWORD_HERE_123
MYSQL_ROOT_PASSWORD=SET_A_PASSWORD_HERE_123 MYSQL_ROOT_PASSWORD=SET_A_PASSWORD_HERE_123
MARIADB_ROOT_PASSWORD=SET_A_PASSWORD_HERE_123
# IdP configuration # IdP configuration
# Complete using the values from the Manage SSO page in the web vault # Complete using the values from the Manage SSO page in the web vault

View File

@ -70,6 +70,20 @@ services:
profiles: profiles:
- mysql - mysql
mariadb:
image: mariadb:10
ports:
- 4306:3306
environment:
MARIADB_USER: maria
MARIADB_PASSWORD: ${MARIADB_ROOT_PASSWORD}
MARIADB_DATABASE: vault_dev
MARIADB_RANDOM_ROOT_PASSWORD: "true"
volumes:
- mariadb_dev_data:/var/lib/mysql
profiles:
- mariadb
idp: idp:
image: kenchan0130/simplesamlphp:1.19.8 image: kenchan0130/simplesamlphp:1.19.8
container_name: idp container_name: idp

View File

@ -5,6 +5,7 @@ param(
[switch]$all, [switch]$all,
[switch]$postgres, [switch]$postgres,
[switch]$mysql, [switch]$mysql,
[switch]$mariadb,
[switch]$mssql, [switch]$mssql,
[switch]$sqlite, [switch]$sqlite,
[switch]$selfhost, [switch]$selfhost,
@ -15,11 +16,15 @@ param(
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
$currentDir = Get-Location $currentDir = Get-Location
if (!$all -and !$postgres -and !$mysql -and !$sqlite) { function Get-IsEFDatabase {
return $postgres -or $mysql -or $mariadb -or $sqlite;
}
if (!$all -and !$(Get-IsEFDatabase)) {
$mssql = $true; $mssql = $true;
} }
if ($all -or $postgres -or $mysql -or $sqlite) { if ($all -or $(Get-IsEFDatabase)) {
dotnet ef *> $null dotnet ef *> $null
if ($LASTEXITCODE -ne 0) { if ($LASTEXITCODE -ne 0) {
Write-Host "Entity Framework Core tools were not found in the dotnet global tools. Attempting to install" Write-Host "Entity Framework Core tools were not found in the dotnet global tools. Attempting to install"
@ -60,9 +65,12 @@ if ($all -or $mssql) {
} }
Foreach ($item in @( Foreach ($item in @(
@($mysql, "MySQL", "MySqlMigrations", "mySql", 2),
@($postgres, "PostgreSQL", "PostgresMigrations", "postgreSql", 0), @($postgres, "PostgreSQL", "PostgresMigrations", "postgreSql", 0),
@($sqlite, "SQLite", "SqliteMigrations", "sqlite", 1) @($sqlite, "SQLite", "SqliteMigrations", "sqlite", 1),
@($mysql, "MySQL", "MySqlMigrations", "mySql", 2),
# MariaDB shares the MySQL connection string in the server config so they are mutually exclusive in that context.
# However they can still be run independently for integration tests.
@($mariadb, "MariaDB", "MySqlMigrations", "mySql", 3)
)) { )) {
if (!$item[0] -and !$all) { if (!$item[0] -and !$all) {
continue continue

View File

@ -40,8 +40,6 @@ export function authenticate(
payload["deviceName"] = "chrome"; payload["deviceName"] = "chrome";
payload["username"] = username; payload["username"] = username;
payload["password"] = password; payload["password"] = password;
params.headers["Auth-Email"] = encoding.b64encode(username);
} else { } else {
payload["scope"] = "api.organization"; payload["scope"] = "api.organization";
payload["grant_type"] = "client_credentials"; payload["grant_type"] = "client_credentials";

View File

@ -462,6 +462,7 @@ public class OrganizationsController : Controller
organization.UsersGetPremium = model.UsersGetPremium; organization.UsersGetPremium = model.UsersGetPremium;
organization.UseSecretsManager = model.UseSecretsManager; organization.UseSecretsManager = model.UseSecretsManager;
organization.UseRiskInsights = model.UseRiskInsights; organization.UseRiskInsights = model.UseRiskInsights;
organization.UseOrganizationDomains = model.UseOrganizationDomains;
organization.UseAdminSponsoredFamilies = model.UseAdminSponsoredFamilies; organization.UseAdminSponsoredFamilies = model.UseAdminSponsoredFamilies;
//secrets //secrets

View File

@ -102,7 +102,7 @@ public class OrganizationEditModel : OrganizationViewModel
MaxAutoscaleSmSeats = org.MaxAutoscaleSmSeats; MaxAutoscaleSmSeats = org.MaxAutoscaleSmSeats;
SmServiceAccounts = org.SmServiceAccounts; SmServiceAccounts = org.SmServiceAccounts;
MaxAutoscaleSmServiceAccounts = org.MaxAutoscaleSmServiceAccounts; MaxAutoscaleSmServiceAccounts = org.MaxAutoscaleSmServiceAccounts;
UseOrganizationDomains = org.UseOrganizationDomains;
_plans = plans; _plans = plans;
} }
@ -186,6 +186,8 @@ public class OrganizationEditModel : OrganizationViewModel
public int? SmServiceAccounts { get; set; } public int? SmServiceAccounts { get; set; }
[Display(Name = "Max Autoscale Machine Accounts")] [Display(Name = "Max Autoscale Machine Accounts")]
public int? MaxAutoscaleSmServiceAccounts { get; set; } public int? MaxAutoscaleSmServiceAccounts { get; set; }
[Display(Name = "Use Organization Domains")]
public bool UseOrganizationDomains { get; set; }
/** /**
* Creates a Plan[] object for use in Javascript * Creates a Plan[] object for use in Javascript
@ -215,6 +217,7 @@ public class OrganizationEditModel : OrganizationViewModel
Has2fa = p.Has2fa, Has2fa = p.Has2fa,
HasApi = p.HasApi, HasApi = p.HasApi,
HasSso = p.HasSso, HasSso = p.HasSso,
HasOrganizationDomains = p.HasOrganizationDomains,
HasKeyConnector = p.HasKeyConnector, HasKeyConnector = p.HasKeyConnector,
HasScim = p.HasScim, HasScim = p.HasScim,
HasResetPassword = p.HasResetPassword, HasResetPassword = p.HasResetPassword,
@ -315,6 +318,7 @@ public class OrganizationEditModel : OrganizationViewModel
existingOrganization.MaxAutoscaleSmSeats = MaxAutoscaleSmSeats; existingOrganization.MaxAutoscaleSmSeats = MaxAutoscaleSmSeats;
existingOrganization.SmServiceAccounts = SmServiceAccounts; existingOrganization.SmServiceAccounts = SmServiceAccounts;
existingOrganization.MaxAutoscaleSmServiceAccounts = MaxAutoscaleSmServiceAccounts; existingOrganization.MaxAutoscaleSmServiceAccounts = MaxAutoscaleSmServiceAccounts;
existingOrganization.UseOrganizationDomains = UseOrganizationDomains;
return existingOrganization; return existingOrganization;
} }
} }

View File

@ -124,6 +124,10 @@
<input type="checkbox" class="form-check-input" asp-for="UseSso" disabled='@(canEditPlan ? null : "disabled")'> <input type="checkbox" class="form-check-input" asp-for="UseSso" disabled='@(canEditPlan ? null : "disabled")'>
<label class="form-check-label" asp-for="UseSso"></label> <label class="form-check-label" asp-for="UseSso"></label>
</div> </div>
<div class="form-check">
<input type="checkbox" class="form-check-input" asp-for="UseOrganizationDomains" disabled='@(canEditPlan ? null : "disabled")'>
<label class="form-check-label" asp-for="UseOrganizationDomains"></label>
</div>
<div class="form-check"> <div class="form-check">
<input type="checkbox" class="form-check-input" asp-for="UseKeyConnector" disabled='@(canEditPlan ? null : "disabled")'> <input type="checkbox" class="form-check-input" asp-for="UseKeyConnector" disabled='@(canEditPlan ? null : "disabled")'>
<label class="form-check-label" asp-for="UseKeyConnector"></label> <label class="form-check-label" asp-for="UseKeyConnector"></label>

View File

@ -69,6 +69,7 @@
document.getElementById('@(nameof(Model.UseGroups))').checked = plan.hasGroups; document.getElementById('@(nameof(Model.UseGroups))').checked = plan.hasGroups;
document.getElementById('@(nameof(Model.UsePolicies))').checked = plan.hasPolicies; document.getElementById('@(nameof(Model.UsePolicies))').checked = plan.hasPolicies;
document.getElementById('@(nameof(Model.UseSso))').checked = plan.hasSso; document.getElementById('@(nameof(Model.UseSso))').checked = plan.hasSso;
document.getElementById('@(nameof(Model.UseOrganizationDomains))').checked = plan.hasOrganizationDomains;
document.getElementById('@(nameof(Model.UseScim))').checked = plan.hasScim; document.getElementById('@(nameof(Model.UseScim))').checked = plan.hasScim;
document.getElementById('@(nameof(Model.UseDirectory))').checked = plan.hasDirectory; document.getElementById('@(nameof(Model.UseDirectory))').checked = plan.hasDirectory;
document.getElementById('@(nameof(Model.UseEvents))').checked = plan.hasEvents; document.getElementById('@(nameof(Model.UseEvents))').checked = plan.hasEvents;

View File

@ -17,7 +17,7 @@ public class ChargeBraintreeModel : IValidatableObject
{ {
if (Id != null) if (Id != null)
{ {
if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u') || if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u' && Id[0] != 'p') ||
!Guid.TryParse(Id.Substring(1, 32), out var guid)) !Guid.TryParse(Id.Substring(1, 32), out var guid))
{ {
yield return new ValidationResult("Customer Id is not a valid format."); yield return new ValidationResult("Customer Id is not a valid format.");

View File

@ -16,7 +16,7 @@
}, },
"devDependencies": { "devDependencies": {
"css-loader": "7.1.2", "css-loader": "7.1.2",
"expose-loader": "5.0.0", "expose-loader": "5.0.1",
"mini-css-extract-plugin": "2.9.2", "mini-css-extract-plugin": "2.9.2",
"sass": "1.85.0", "sass": "1.85.0",
"sass-loader": "16.0.4", "sass-loader": "16.0.4",
@ -1084,9 +1084,9 @@
} }
}, },
"node_modules/expose-loader": { "node_modules/expose-loader": {
"version": "5.0.0", "version": "5.0.1",
"resolved": "https://registry.npmjs.org/expose-loader/-/expose-loader-5.0.0.tgz", "resolved": "https://registry.npmjs.org/expose-loader/-/expose-loader-5.0.1.tgz",
"integrity": "sha512-BtUqYRmvx1bEY5HN6eK2I9URUZgNmN0x5UANuocaNjXSgfoDlkXt+wyEMe7i5DzDNh2BKJHPc5F4rBwEdSQX6w==", "integrity": "sha512-5YPZuszN/eWND/B+xuq5nIpb/l5TV1HYmdO6SubYtHv+HenVw9/6bn33Mm5reY8DNid7AVtbARvyUD34edfCtg==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"engines": { "engines": {

View File

@ -15,7 +15,7 @@
}, },
"devDependencies": { "devDependencies": {
"css-loader": "7.1.2", "css-loader": "7.1.2",
"expose-loader": "5.0.0", "expose-loader": "5.0.1",
"mini-css-extract-plugin": "2.9.2", "mini-css-extract-plugin": "2.9.2",
"sass": "1.85.0", "sass": "1.85.0",
"sass-loader": "16.0.4", "sass-loader": "16.0.4",

View File

@ -75,6 +75,8 @@ public class OrganizationCreateRequestModel : IValidatableObject
public string InitiationPath { get; set; } public string InitiationPath { get; set; }
public bool SkipTrial { get; set; }
public virtual OrganizationSignup ToOrganizationSignup(User user) public virtual OrganizationSignup ToOrganizationSignup(User user)
{ {
var orgSignup = new OrganizationSignup var orgSignup = new OrganizationSignup
@ -107,6 +109,7 @@ public class OrganizationCreateRequestModel : IValidatableObject
BillingAddressCountry = BillingAddressCountry, BillingAddressCountry = BillingAddressCountry,
}, },
InitiationPath = InitiationPath, InitiationPath = InitiationPath,
SkipTrial = SkipTrial
}; };
Keys?.ToOrganizationSignup(orgSignup); Keys?.ToOrganizationSignup(orgSignup);

View File

@ -64,6 +64,7 @@ public class OrganizationResponseModel : ResponseModel
LimitItemDeletion = organization.LimitItemDeletion; LimitItemDeletion = organization.LimitItemDeletion;
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems; AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems;
UseRiskInsights = organization.UseRiskInsights; UseRiskInsights = organization.UseRiskInsights;
UseOrganizationDomains = organization.UseOrganizationDomains;
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
} }
@ -111,6 +112,7 @@ public class OrganizationResponseModel : ResponseModel
public bool LimitItemDeletion { get; set; } public bool LimitItemDeletion { get; set; }
public bool AllowAdminAccessToAllCollectionItems { get; set; } public bool AllowAdminAccessToAllCollectionItems { get; set; }
public bool UseRiskInsights { get; set; } public bool UseRiskInsights { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; } public bool UseAdminSponsoredFamilies { get; set; }
} }

View File

@ -73,6 +73,7 @@ public class ProfileOrganizationResponseModel : ResponseModel
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems; AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems;
UserIsClaimedByOrganization = organizationIdsClaimingUser.Contains(organization.OrganizationId); UserIsClaimedByOrganization = organizationIdsClaimingUser.Contains(organization.OrganizationId);
UseRiskInsights = organization.UseRiskInsights; UseRiskInsights = organization.UseRiskInsights;
UseOrganizationDomains = organization.UseOrganizationDomains;
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
if (organization.SsoConfig != null) if (organization.SsoConfig != null)
@ -153,6 +154,7 @@ public class ProfileOrganizationResponseModel : ResponseModel
/// </remarks> /// </remarks>
public bool UserIsClaimedByOrganization { get; set; } public bool UserIsClaimedByOrganization { get; set; }
public bool UseRiskInsights { get; set; } public bool UseRiskInsights { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; } public bool UseAdminSponsoredFamilies { get; set; }
public bool IsAdminInitiated { get; set; } public bool IsAdminInitiated { get; set; }
} }

View File

@ -50,6 +50,7 @@ public class ProfileProviderOrganizationResponseModel : ProfileOrganizationRespo
LimitItemDeletion = organization.LimitItemDeletion; LimitItemDeletion = organization.LimitItemDeletion;
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems; AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems;
UseRiskInsights = organization.UseRiskInsights; UseRiskInsights = organization.UseRiskInsights;
UseOrganizationDomains = organization.UseOrganizationDomains;
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
} }
} }

View File

@ -1,7 +1,7 @@
#nullable enable #nullable enable
using Bit.Api.Billing.Models.Responses; using Bit.Api.Billing.Models.Responses;
using Bit.Core.Billing.Models.Api.Requests.Accounts;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Requests;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;

View File

@ -1,5 +1,5 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Models.Api.Requests.Organizations; using Bit.Core.Billing.Tax.Requests;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;

View File

@ -9,6 +9,7 @@ using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
@ -291,15 +292,17 @@ public class OrganizationBillingController(
sale.Organization.PlanType = plan.Type; sale.Organization.PlanType = plan.Type;
sale.Organization.Plan = plan.Name; sale.Organization.Plan = plan.Name;
sale.SubscriptionSetup.SkipTrial = true; sale.SubscriptionSetup.SkipTrial = true;
await organizationBillingService.Finalize(sale);
if (organizationSignup.PaymentMethodType == null || string.IsNullOrEmpty(organizationSignup.PaymentToken))
{
return Error.BadRequest("A payment method is required to restart the subscription.");
}
var org = await organizationRepository.GetByIdAsync(organizationId); var org = await organizationRepository.GetByIdAsync(organizationId);
Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine."); Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine.");
if (organizationSignup.PaymentMethodType != null) var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken);
{ var taxInformation = TaxInformation.From(organizationSignup.TaxInfo);
var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken); await organizationBillingService.UpdatePaymentMethod(org, paymentSource, taxInformation);
var taxInformation = TaxInformation.From(organizationSignup.TaxInfo); await organizationBillingService.Finalize(sale);
await organizationBillingService.UpdatePaymentMethod(org, paymentSource, taxInformation);
}
return TypedResults.Ok(); return TypedResults.Ok();
} }

View File

@ -222,6 +222,20 @@ public class OrganizationSponsorshipsController : Controller
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
} }
[Authorize("Application")]
[HttpDelete("{sponsoringOrgId}/{sponsoredFriendlyName}/revoke")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task AdminInitiatedRevokeSponsorshipAsync(Guid sponsoringOrgId, string sponsoredFriendlyName)
{
var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrgId);
var existingOrgSponsorship = sponsorships.FirstOrDefault(s => s.FriendlyName != null && s.FriendlyName.Equals(sponsoredFriendlyName, StringComparison.OrdinalIgnoreCase));
if (existingOrgSponsorship == null)
{
throw new BadRequestException("The specified sponsored organization could not be found under the given sponsoring organization.");
}
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
}
[Authorize("Application")] [Authorize("Application")]
[HttpDelete("sponsored/{sponsoredOrgId}")] [HttpDelete("sponsored/{sponsoredOrgId}")]
[HttpPost("sponsored/{sponsoredOrgId}/remove")] [HttpPost("sponsored/{sponsoredOrgId}/remove")]

View File

@ -109,28 +109,6 @@ public class OrganizationsController(
return license; return license;
} }
[HttpPost("{id:guid}/payment")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task PostPayment(Guid id, [FromBody] PaymentRequestModel model)
{
if (!await currentContext.EditPaymentMethods(id))
{
throw new NotFoundException();
}
await organizationService.ReplacePaymentMethodAsync(id, model.PaymentToken,
model.PaymentMethodType.Value, new TaxInfo
{
BillingAddressLine1 = model.Line1,
BillingAddressLine2 = model.Line2,
BillingAddressState = model.State,
BillingAddressCity = model.City,
BillingAddressPostalCode = model.PostalCode,
BillingAddressCountry = model.Country,
TaxIdNumber = model.TaxId,
});
}
[HttpPost("{id:guid}/upgrade")] [HttpPost("{id:guid}/upgrade")]
[SelfHosted(NotSelfHostedOnly = true)] [SelfHosted(NotSelfHostedOnly = true)]
public async Task<PaymentResponseModel> PostUpgrade(Guid id, [FromBody] OrganizationUpgradeRequestModel model) public async Task<PaymentResponseModel> PostUpgrade(Guid id, [FromBody] OrganizationUpgradeRequestModel model)

View File

@ -6,6 +6,7 @@ using Bit.Core.Billing.Models;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Models.BitStripe; using Bit.Core.Models.BitStripe;
using Bit.Core.Services; using Bit.Core.Services;

View File

@ -1,4 +1,4 @@
using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Services;
using Bit.Core.Services; using Bit.Core.Services;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Http.HttpResults;

View File

@ -0,0 +1,36 @@
using Bit.Api.Billing.Models.Requests;
using Bit.Core.Billing.Tax.Commands;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
namespace Bit.Api.Billing.Controllers;
[Authorize("Application")]
[Route("tax")]
public class TaxController(
IPreviewTaxAmountCommand previewTaxAmountCommand) : BaseBillingController
{
[HttpPost("preview-amount/organization-trial")]
public async Task<IResult> PreviewTaxAmountForOrganizationTrialAsync(
[FromBody] PreviewTaxAmountForOrganizationTrialRequestBody requestBody)
{
var parameters = new OrganizationTrialParameters
{
PlanType = requestBody.PlanType,
ProductType = requestBody.ProductType,
TaxInformation = new OrganizationTrialParameters.TaxInformationDTO
{
Country = requestBody.TaxInformation.Country,
PostalCode = requestBody.TaxInformation.PostalCode,
TaxId = requestBody.TaxInformation.TaxId
}
};
var result = await previewTaxAmountCommand.Run(parameters);
return result.Match<IResult>(
taxAmount => TypedResults.Ok(new { TaxAmount = taxAmount }),
badRequest => Error.BadRequest(badRequest.TranslationKey),
unhandled => Error.ServerError(unhandled.TranslationKey));
}
}

View File

@ -0,0 +1,27 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Enums;
namespace Bit.Api.Billing.Models.Requests;
public class PreviewTaxAmountForOrganizationTrialRequestBody
{
[Required]
public PlanType PlanType { get; set; }
[Required]
public ProductType ProductType { get; set; }
[Required] public TaxInformationDTO TaxInformation { get; set; } = null!;
public class TaxInformationDTO
{
[Required]
public string Country { get; set; } = null!;
[Required]
public string PostalCode { get; set; } = null!;
public string? TaxId { get; set; }
}
}

View File

@ -1,5 +1,5 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Tax.Models;
namespace Bit.Api.Billing.Models.Requests; namespace Bit.Api.Billing.Models.Requests;

View File

@ -1,4 +1,5 @@
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Tax.Models;
namespace Bit.Api.Billing.Models.Responses; namespace Bit.Api.Billing.Models.Responses;

View File

@ -2,6 +2,7 @@
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Tax.Models;
using Stripe; using Stripe;
namespace Bit.Api.Billing.Models.Responses; namespace Bit.Api.Billing.Models.Responses;

View File

@ -1,4 +1,4 @@
using Bit.Core.Billing.Models; using Bit.Core.Billing.Tax.Models;
namespace Bit.Api.Billing.Models.Responses; namespace Bit.Api.Billing.Models.Responses;

View File

@ -1,6 +1,10 @@
using Bit.Api.Models.Request.Organizations; using Bit.Api.AdminConsole.Authorization.Requirements;
using Bit.Api.Models.Request.Organizations;
using Bit.Api.Models.Response;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Api.Response.OrganizationSponsorships;
using Bit.Core.Models.Data.Organizations.OrganizationSponsorships;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
@ -22,6 +26,7 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
private readonly IAuthorizationService _authorizationService;
public SelfHostedOrganizationSponsorshipsController( public SelfHostedOrganizationSponsorshipsController(
ICreateSponsorshipCommand offerSponsorshipCommand, ICreateSponsorshipCommand offerSponsorshipCommand,
@ -30,7 +35,8 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationSponsorshipRepository organizationSponsorshipRepository,
IOrganizationUserRepository organizationUserRepository, IOrganizationUserRepository organizationUserRepository,
ICurrentContext currentContext, ICurrentContext currentContext,
IFeatureService featureService IFeatureService featureService,
IAuthorizationService authorizationService
) )
{ {
_offerSponsorshipCommand = offerSponsorshipCommand; _offerSponsorshipCommand = offerSponsorshipCommand;
@ -40,6 +46,7 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
_organizationUserRepository = organizationUserRepository; _organizationUserRepository = organizationUserRepository;
_currentContext = currentContext; _currentContext = currentContext;
_featureService = featureService; _featureService = featureService;
_authorizationService = authorizationService;
} }
[HttpPost("{sponsoringOrgId}/families-for-enterprise")] [HttpPost("{sponsoringOrgId}/families-for-enterprise")]
@ -84,4 +91,41 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
} }
[HttpDelete("{sponsoringOrgId}/{sponsoredFriendlyName}/revoke")]
public async Task AdminInitiatedRevokeSponsorshipAsync(Guid sponsoringOrgId, string sponsoredFriendlyName)
{
var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrgId);
var existingOrgSponsorship = sponsorships.FirstOrDefault(s => s.FriendlyName != null && s.FriendlyName.Equals(sponsoredFriendlyName, StringComparison.OrdinalIgnoreCase));
if (existingOrgSponsorship == null)
{
throw new BadRequestException("The specified sponsored organization could not be found under the given sponsoring organization.");
}
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
}
[Authorize("Application")]
[HttpGet("{orgId}/sponsored")]
public async Task<ListResponseModel<OrganizationSponsorshipInvitesResponseModel>> GetSponsoredOrganizations(Guid orgId)
{
var sponsoringOrg = await _organizationRepository.GetByIdAsync(orgId);
if (sponsoringOrg == null)
{
throw new NotFoundException();
}
var authorizationResult = await _authorizationService.AuthorizeAsync(User, orgId, new ManageUsersRequirement());
if (!authorizationResult.Succeeded)
{
throw new UnauthorizedAccessException();
}
var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(orgId);
return new ListResponseModel<OrganizationSponsorshipInvitesResponseModel>(
sponsorships
.Where(s => s.IsAdminInitiated)
.Select(s => new OrganizationSponsorshipInvitesResponseModel(new OrganizationSponsorshipData(s)))
);
}
} }

View File

@ -4,6 +4,7 @@ namespace Bit.Api.Tools.Models.Response;
public class MemberCipherDetailsResponseModel public class MemberCipherDetailsResponseModel
{ {
public Guid? UserGuid { get; set; }
public string UserName { get; set; } public string UserName { get; set; }
public string Email { get; set; } public string Email { get; set; }
public bool UsesKeyConnector { get; set; } public bool UsesKeyConnector { get; set; }
@ -16,6 +17,7 @@ public class MemberCipherDetailsResponseModel
public MemberCipherDetailsResponseModel(MemberAccessCipherDetails memberAccessCipherDetails) public MemberCipherDetailsResponseModel(MemberAccessCipherDetails memberAccessCipherDetails)
{ {
this.UserGuid = memberAccessCipherDetails.UserGuid;
this.UserName = memberAccessCipherDetails.UserName; this.UserName = memberAccessCipherDetails.UserName;
this.Email = memberAccessCipherDetails.Email; this.Email = memberAccessCipherDetails.Email;
this.UsesKeyConnector = memberAccessCipherDetails.UsesKeyConnector; this.UsesKeyConnector = memberAccessCipherDetails.UsesKeyConnector;

View File

@ -32,6 +32,7 @@ public class PlanResponseModel : ResponseModel
HasTotp = plan.HasTotp; HasTotp = plan.HasTotp;
Has2fa = plan.Has2fa; Has2fa = plan.Has2fa;
HasSso = plan.HasSso; HasSso = plan.HasSso;
HasOrganizationDomains = plan.HasOrganizationDomains;
HasResetPassword = plan.HasResetPassword; HasResetPassword = plan.HasResetPassword;
UsersGetPremium = plan.UsersGetPremium; UsersGetPremium = plan.UsersGetPremium;
UpgradeSortOrder = plan.UpgradeSortOrder; UpgradeSortOrder = plan.UpgradeSortOrder;
@ -71,6 +72,7 @@ public class PlanResponseModel : ResponseModel
public bool Has2fa { get; set; } public bool Has2fa { get; set; }
public bool HasApi { get; set; } public bool HasApi { get; set; }
public bool HasSso { get; set; } public bool HasSso { get; set; }
public bool HasOrganizationDomains { get; set; }
public bool HasResetPassword { get; set; } public bool HasResetPassword { get; set; }
public bool UsersGetPremium { get; set; } public bool UsersGetPremium { get; set; }

View File

@ -315,26 +315,10 @@ public class CiphersController : Controller
{ {
var org = _currentContext.GetOrganization(organizationId); var org = _currentContext.GetOrganization(organizationId);
// If we're not an "admin", we don't need to check the ciphers // If we're not an "admin" or if we're not a provider user we don't need to check the ciphers
if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true })) if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true }) || await _currentContext.ProviderUserForOrgAsync(organizationId))
{ {
// Are we a provider user? If so, we need to be sure we're not restricted return false;
// Once the feature flag is removed, this check can be combined with the above
if (await _currentContext.ProviderUserForOrgAsync(organizationId))
{
// Provider is restricted from editing ciphers, so we're not an "admin"
if (_featureService.IsEnabled(FeatureFlagKeys.RestrictProviderAccess))
{
return false;
}
// Provider is unrestricted, so we're an "admin", don't return early
}
else
{
// Not a provider or admin
return false;
}
} }
// We know we're an "admin", now check the ciphers explicitly (in case admins are restricted) // We know we're an "admin", now check the ciphers explicitly (in case admins are restricted)
@ -350,26 +334,10 @@ public class CiphersController : Controller
var org = _currentContext.GetOrganization(organizationId); var org = _currentContext.GetOrganization(organizationId);
// If we're not an "admin", we don't need to check the ciphers // If we're not an "admin" or if we're a provider user we don't need to check the ciphers
if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true })) if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true }) || await _currentContext.ProviderUserForOrgAsync(organizationId))
{ {
// Are we a provider user? If so, we need to be sure we're not restricted return false;
// Once the feature flag is removed, this check can be combined with the above
if (await _currentContext.ProviderUserForOrgAsync(organizationId))
{
// Provider is restricted from editing ciphers, so we're not an "admin"
if (_featureService.IsEnabled(FeatureFlagKeys.RestrictProviderAccess))
{
return false;
}
// Provider is unrestricted, so we're an "admin", don't return early
}
else
{
// Not a provider or admin
return false;
}
} }
// If the user can edit all ciphers for the organization, just check they all belong to the org // If the user can edit all ciphers for the organization, just check they all belong to the org
@ -462,10 +430,10 @@ public class CiphersController : Controller
return true; return true;
} }
// Provider users can edit all ciphers if RestrictProviderAccess is disabled // Provider users cannot edit ciphers
if (await _currentContext.ProviderUserForOrgAsync(organizationId)) if (await _currentContext.ProviderUserForOrgAsync(organizationId))
{ {
return !_featureService.IsEnabled(FeatureFlagKeys.RestrictProviderAccess); return false;
} }
return false; return false;
@ -485,10 +453,10 @@ public class CiphersController : Controller
return true; return true;
} }
// Provider users can only access organization ciphers if RestrictProviderAccess is disabled // Provider users cannot access organization ciphers
if (await _currentContext.ProviderUserForOrgAsync(organizationId)) if (await _currentContext.ProviderUserForOrgAsync(organizationId))
{ {
return !_featureService.IsEnabled(FeatureFlagKeys.RestrictProviderAccess); return false;
} }
return false; return false;
@ -508,10 +476,10 @@ public class CiphersController : Controller
return true; return true;
} }
// Provider users can only access all ciphers if RestrictProviderAccess is disabled // Provider users cannot access ciphers
if (await _currentContext.ProviderUserForOrgAsync(organizationId)) if (await _currentContext.ProviderUserForOrgAsync(organizationId))
{ {
return !_featureService.IsEnabled(FeatureFlagKeys.RestrictProviderAccess); return false;
} }
return false; return false;

View File

@ -63,6 +63,12 @@ public class FreshdeskController : Controller
note += $"<li>Region: {_billingSettings.FreshDesk.Region}</li>"; note += $"<li>Region: {_billingSettings.FreshDesk.Region}</li>";
var customFields = new Dictionary<string, object>(); var customFields = new Dictionary<string, object>();
var user = await _userRepository.GetByEmailAsync(ticketContactEmail); var user = await _userRepository.GetByEmailAsync(ticketContactEmail);
if (user == null)
{
note += $"<li>No user found: {ticketContactEmail}</li>";
await CreateNote(ticketId, note);
}
if (user != null) if (user != null)
{ {
var userLink = $"{_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}"; var userLink = $"{_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}";
@ -121,18 +127,7 @@ public class FreshdeskController : Controller
Content = JsonContent.Create(updateBody), Content = JsonContent.Create(updateBody),
}; };
await CallFreshdeskApiAsync(updateRequest); await CallFreshdeskApiAsync(updateRequest);
await CreateNote(ticketId, note);
var noteBody = new Dictionary<string, object>
{
{ "body", $"<ul>{note}</ul>" },
{ "private", true }
};
var noteRequest = new HttpRequestMessage(HttpMethod.Post,
string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}/notes", ticketId))
{
Content = JsonContent.Create(noteBody),
};
await CallFreshdeskApiAsync(noteRequest);
} }
return new OkResult(); return new OkResult();
@ -208,6 +203,21 @@ public class FreshdeskController : Controller
return true; return true;
} }
private async Task CreateNote(string ticketId, string note)
{
var noteBody = new Dictionary<string, object>
{
{ "body", $"<ul>{note}</ul>" },
{ "private", true }
};
var noteRequest = new HttpRequestMessage(HttpMethod.Post,
string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}/notes", ticketId))
{
Content = JsonContent.Create(noteBody),
};
await CallFreshdeskApiAsync(noteRequest);
}
private async Task AddAnswerNoteToTicketAsync(string note, string ticketId) private async Task AddAnswerNoteToTicketAsync(string note, string ticketId)
{ {
// if there is no content, then we don't need to add a note // if there is no content, then we don't need to add a note

View File

@ -1,11 +1,11 @@
using Bit.Core; using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
@ -25,8 +25,7 @@ public class UpcomingInvoiceHandler(
IStripeEventService stripeEventService, IStripeEventService stripeEventService,
IStripeEventUtilityService stripeEventUtilityService, IStripeEventUtilityService stripeEventUtilityService,
IUserRepository userRepository, IUserRepository userRepository,
IValidateSponsorshipCommand validateSponsorshipCommand, IValidateSponsorshipCommand validateSponsorshipCommand)
IAutomaticTaxFactory automaticTaxFactory)
: IUpcomingInvoiceHandler : IUpcomingInvoiceHandler
{ {
public async Task HandleAsync(Event parsedEvent) public async Task HandleAsync(Event parsedEvent)
@ -46,6 +45,8 @@ public class UpcomingInvoiceHandler(
var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata);
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (organizationId.HasValue) if (organizationId.HasValue)
{ {
var organization = await organizationRepository.GetByIdAsync(organizationId.Value); var organization = await organizationRepository.GetByIdAsync(organizationId.Value);
@ -55,7 +56,7 @@ public class UpcomingInvoiceHandler(
return; return;
} }
await TryEnableAutomaticTaxAsync(subscription); await AlignOrganizationTaxConcernsAsync(organization, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge);
var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); var plan = await pricingClient.GetPlanOrThrow(organization.PlanType);
@ -100,7 +101,25 @@ public class UpcomingInvoiceHandler(
return; return;
} }
await TryEnableAutomaticTaxAsync(subscription); if (!subscription.AutomaticTax.Enabled && subscription.Customer.HasRecognizedTaxLocation())
{
try
{
await stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}",
user.Id,
parsedEvent.Id);
}
}
if (user.Premium) if (user.Premium)
{ {
@ -116,7 +135,7 @@ public class UpcomingInvoiceHandler(
return; return;
} }
await TryEnableAutomaticTaxAsync(subscription); await AlignProviderTaxConcernsAsync(provider, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge);
await SendUpcomingInvoiceEmailsAsync(new List<string> { provider.BillingEmail }, invoice); await SendUpcomingInvoiceEmailsAsync(new List<string> { provider.BillingEmail }, invoice);
} }
@ -139,50 +158,123 @@ public class UpcomingInvoiceHandler(
} }
} }
private async Task TryEnableAutomaticTaxAsync(Subscription subscription) private async Task AlignOrganizationTaxConcernsAsync(
Organization organization,
Subscription subscription,
string eventId,
bool setNonUSBusinessUseToReverseCharge)
{ {
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var nonUSBusinessUse =
{ organization.PlanType.GetProductTier() != ProductTierType.Families &&
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id)); subscription.Customer.Address.Country != "US";
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (updateOptions == null) bool setAutomaticTaxToEnabled;
if (setNonUSBusinessUseToReverseCharge)
{
if (nonUSBusinessUse && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
{ {
return; try
{
await stripeFacade.UpdateCustomer(subscription.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}",
organization.Id,
eventId);
}
} }
await stripeFacade.UpdateSubscription(subscription.Id, updateOptions); setAutomaticTaxToEnabled = true;
return;
} }
else
if (subscription.AutomaticTax.Enabled ||
!subscription.Customer.HasBillingLocation() ||
await IsNonTaxableNonUSBusinessUseSubscription(subscription))
{ {
return; setAutomaticTaxToEnabled =
subscription.Customer.HasRecognizedTaxLocation() &&
(subscription.Customer.Address.Country == "US" ||
(nonUSBusinessUse && subscription.Customer.TaxIds.Any()));
} }
await stripeFacade.UpdateSubscription(subscription.Id, if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled)
new SubscriptionUpdateOptions {
try
{ {
DefaultTaxRates = [], await stripeFacade.UpdateSubscription(subscription.Id,
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } new SubscriptionUpdateOptions
}); {
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}",
organization.Id,
eventId);
}
}
}
return; private async Task AlignProviderTaxConcernsAsync(
Provider provider,
Subscription subscription,
string eventId,
bool setNonUSBusinessUseToReverseCharge)
{
bool setAutomaticTaxToEnabled;
async Task<bool> IsNonTaxableNonUSBusinessUseSubscription(Subscription localSubscription) if (setNonUSBusinessUseToReverseCharge)
{ {
var familyPriceIds = (await Task.WhenAll( if (subscription.Customer.Address.Country != "US" && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), {
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually))) try
.Select(plan => plan.PasswordManager.StripePlanId); {
await stripeFacade.UpdateCustomer(subscription.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}",
provider.Id,
eventId);
}
}
return localSubscription.Customer.Address.Country != "US" && setAutomaticTaxToEnabled = true;
localSubscription.Metadata.ContainsKey(StripeConstants.MetadataKeys.OrganizationId) && }
!localSubscription.Items.Select(item => item.Price.Id).Intersect(familyPriceIds).Any() && else
!localSubscription.Customer.TaxIds.Any(); {
setAutomaticTaxToEnabled =
subscription.Customer.HasRecognizedTaxLocation() &&
(subscription.Customer.Address.Country == "US" ||
subscription.Customer.TaxIds.Any());
}
if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled)
{
try
{
await stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}",
provider.Id,
eventId);
}
} }
} }
} }

View File

@ -114,6 +114,11 @@ public class Organization : ITableObject<Guid>, IStorableSubscriber, IRevisable,
/// </summary> /// </summary>
public bool UseRiskInsights { get; set; } public bool UseRiskInsights { get; set; }
/// <summary>
/// If true, the organization can claim domains, which unlocks additional enterprise features
/// </summary>
public bool UseOrganizationDomains { get; set; }
/// <summary> /// <summary>
/// If set to true, admins can initiate organization-issued sponsorships. /// If set to true, admins can initiate organization-issued sponsorships.
/// </summary> /// </summary>
@ -319,5 +324,7 @@ public class Organization : ITableObject<Guid>, IStorableSubscriber, IRevisable,
SmSeats = license.SmSeats; SmSeats = license.SmSeats;
SmServiceAccounts = license.SmServiceAccounts; SmServiceAccounts = license.SmServiceAccounts;
UseRiskInsights = license.UseRiskInsights; UseRiskInsights = license.UseRiskInsights;
UseOrganizationDomains = license.UseOrganizationDomains;
UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies;
} }
} }

View File

@ -26,6 +26,7 @@ public class OrganizationAbility
LimitItemDeletion = organization.LimitItemDeletion; LimitItemDeletion = organization.LimitItemDeletion;
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems; AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems;
UseRiskInsights = organization.UseRiskInsights; UseRiskInsights = organization.UseRiskInsights;
UseOrganizationDomains = organization.UseOrganizationDomains;
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
} }
@ -46,5 +47,6 @@ public class OrganizationAbility
public bool LimitItemDeletion { get; set; } public bool LimitItemDeletion { get; set; }
public bool AllowAdminAccessToAllCollectionItems { get; set; } public bool AllowAdminAccessToAllCollectionItems { get; set; }
public bool UseRiskInsights { get; set; } public bool UseRiskInsights { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; } public bool UseAdminSponsoredFamilies { get; set; }
} }

View File

@ -59,6 +59,7 @@ public class OrganizationUserOrganizationDetails
public bool LimitItemDeletion { get; set; } public bool LimitItemDeletion { get; set; }
public bool AllowAdminAccessToAllCollectionItems { get; set; } public bool AllowAdminAccessToAllCollectionItems { get; set; }
public bool UseRiskInsights { get; set; } public bool UseRiskInsights { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; } public bool UseAdminSponsoredFamilies { get; set; }
public bool? IsAdminInitiated { get; set; } public bool? IsAdminInitiated { get; set; }
} }

View File

@ -150,6 +150,7 @@ public class SelfHostedOrganizationDetails : Organization
AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems, AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems,
Status = Status, Status = Status,
UseRiskInsights = UseRiskInsights, UseRiskInsights = UseRiskInsights,
UseAdminSponsoredFamilies = UseAdminSponsoredFamilies,
}; };
} }
} }

View File

@ -45,6 +45,7 @@ public class ProviderUserOrganizationDetails
public bool LimitItemDeletion { get; set; } public bool LimitItemDeletion { get; set; }
public bool AllowAdminAccessToAllCollectionItems { get; set; } public bool AllowAdminAccessToAllCollectionItems { get; set; }
public bool UseRiskInsights { get; set; } public bool UseRiskInsights { get; set; }
public bool UseOrganizationDomains { get; set; }
public bool UseAdminSponsoredFamilies { get; set; } public bool UseAdminSponsoredFamilies { get; set; }
public ProviderType ProviderType { get; set; } public ProviderType ProviderType { get; set; }
} }

View File

@ -24,9 +24,7 @@ public class GetOrganizationUsersClaimedStatusQuery : IGetOrganizationUsersClaim
// Users can only be claimed by an Organization that is enabled and can have organization domains // Users can only be claimed by an Organization that is enabled and can have organization domains
var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId); var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId);
// TODO: Replace "UseSso" with a new organization ability like "UseOrganizationDomains" (PM-11622). if (organizationAbility is { Enabled: true, UseOrganizationDomains: true })
// Verified domains were tied to SSO, so we currently check the "UseSso" organization ability.
if (organizationAbility is { Enabled: true, UseSso: true })
{ {
// Get all organization users with claimed domains by the organization // Get all organization users with claimed domains by the organization
var organizationUsersWithClaimedDomain = await _organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organizationId); var organizationUsersWithClaimedDomain = await _organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organizationId);

View File

@ -104,7 +104,8 @@ public class CloudOrganizationSignUpCommand(
RevisionDate = DateTime.UtcNow, RevisionDate = DateTime.UtcNow,
Status = OrganizationStatusType.Created, Status = OrganizationStatusType.Created,
UsePasswordManager = true, UsePasswordManager = true,
UseSecretsManager = signup.UseSecretsManager UseSecretsManager = signup.UseSecretsManager,
UseOrganizationDomains = plan.HasOrganizationDomains,
}; };
if (signup.UseSecretsManager) if (signup.UseSecretsManager)

View File

@ -11,8 +11,6 @@ namespace Bit.Core.Services;
public interface IOrganizationService public interface IOrganizationService
{ {
Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType,
TaxInfo taxInfo);
Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null); Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null);
Task ReinstateSubscriptionAsync(Guid organizationId); Task ReinstateSubscriptionAsync(Guid organizationId);
Task<string> AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb); Task<string> AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb);

View File

@ -144,27 +144,6 @@ public class OrganizationService : IOrganizationService
_sendOrganizationInvitesCommand = sendOrganizationInvitesCommand; _sendOrganizationInvitesCommand = sendOrganizationInvitesCommand;
} }
public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken,
PaymentMethodType paymentMethodType, TaxInfo taxInfo)
{
var organization = await GetOrgById(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _paymentService.SaveTaxInfoAsync(organization, taxInfo);
var updated = await _paymentService.UpdatePaymentMethodAsync(
organization,
paymentMethodType,
paymentToken,
taxInfo);
if (updated)
{
await ReplaceAndUpdateCacheAsync(organization);
}
}
public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null) public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null)
{ {
var organization = await GetOrgById(organizationId); var organization = await GetOrgById(organizationId);
@ -449,6 +428,7 @@ public class OrganizationService : IOrganizationService
MaxStorageGb = 1, MaxStorageGb = 1,
UsePolicies = plan.HasPolicies, UsePolicies = plan.HasPolicies,
UseSso = plan.HasSso, UseSso = plan.HasSso,
UseOrganizationDomains = plan.HasOrganizationDomains,
UseGroups = plan.HasGroups, UseGroups = plan.HasGroups,
UseEvents = plan.HasEvents, UseEvents = plan.HasEvents,
UseDirectory = plan.HasDirectory, UseDirectory = plan.HasDirectory,
@ -570,6 +550,8 @@ public class OrganizationService : IOrganizationService
SmSeats = license.SmSeats, SmSeats = license.SmSeats,
SmServiceAccounts = license.SmServiceAccounts, SmServiceAccounts = license.SmServiceAccounts,
UseRiskInsights = license.UseRiskInsights, UseRiskInsights = license.UseRiskInsights,
UseOrganizationDomains = license.UseOrganizationDomains,
UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies,
}; };
var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false); var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false);

View File

@ -2,9 +2,24 @@
public enum EmergencyAccessStatusType : byte public enum EmergencyAccessStatusType : byte
{ {
/// <summary>
/// The user has been invited to be an emergency contact.
/// </summary>
Invited = 0, Invited = 0,
/// <summary>
/// The invited user, "grantee", has accepted the request to be an emergency contact.
/// </summary>
Accepted = 1, Accepted = 1,
/// <summary>
/// The inviting user, "grantor", has approved the grantee's acceptance.
/// </summary>
Confirmed = 2, Confirmed = 2,
/// <summary>
/// The grantee has initiated the recovery process.
/// </summary>
RecoveryInitiated = 3, RecoveryInitiated = 3,
/// <summary>
/// The grantee has excercised their emergency access.
/// </summary>
RecoveryApproved = 4, RecoveryApproved = 4,
} }

View File

@ -3,6 +3,7 @@ using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums; using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Models.Data;
@ -20,6 +21,15 @@ public interface IEmergencyAccessService
Task InitiateAsync(Guid id, User initiatingUser); Task InitiateAsync(Guid id, User initiatingUser);
Task ApproveAsync(Guid id, User approvingUser); Task ApproveAsync(Guid id, User approvingUser);
Task RejectAsync(Guid id, User rejectingUser); Task RejectAsync(Guid id, User rejectingUser);
/// <summary>
/// This request is made by the Grantee user to fetch the policies <see cref="Policy"/> for the Grantor User.
/// The Grantor User has to be the owner of the organization. <see cref="OrganizationUserType"/>
/// If the Grantor user has OrganizationUserType.Owner then the policies for the _Grantor_ user
/// are returned.
/// </summary>
/// <param name="id">EmergencyAccess.Id being acted on</param>
/// <param name="requestingUser">User making the request, this is the Grantee</param>
/// <returns>null if the GrantorUser is not an organization owner; A list of policies otherwise.</returns>
Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser); Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser);
Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser); Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser);
Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key); Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key);

View File

@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities; using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums; using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities; using Bit.Core.Entities;
@ -16,7 +15,6 @@ using Bit.Core.Tokens;
using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Models.Data;
using Bit.Core.Vault.Repositories; using Bit.Core.Vault.Repositories;
using Bit.Core.Vault.Services; using Bit.Core.Vault.Services;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Auth.Services; namespace Bit.Core.Auth.Services;
@ -31,8 +29,6 @@ public class EmergencyAccessService : IEmergencyAccessService
private readonly IMailService _mailService; private readonly IMailService _mailService;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly GlobalSettings _globalSettings; private readonly GlobalSettings _globalSettings;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IOrganizationService _organizationService;
private readonly IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> _dataProtectorTokenizer; private readonly IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> _dataProtectorTokenizer;
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
@ -45,9 +41,7 @@ public class EmergencyAccessService : IEmergencyAccessService
ICipherService cipherService, ICipherService cipherService,
IMailService mailService, IMailService mailService,
IUserService userService, IUserService userService,
IPasswordHasher<User> passwordHasher,
GlobalSettings globalSettings, GlobalSettings globalSettings,
IOrganizationService organizationService,
IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> dataProtectorTokenizer, IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> dataProtectorTokenizer,
IRemoveOrganizationUserCommand removeOrganizationUserCommand) IRemoveOrganizationUserCommand removeOrganizationUserCommand)
{ {
@ -59,9 +53,7 @@ public class EmergencyAccessService : IEmergencyAccessService
_cipherService = cipherService; _cipherService = cipherService;
_mailService = mailService; _mailService = mailService;
_userService = userService; _userService = userService;
_passwordHasher = passwordHasher;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_organizationService = organizationService;
_dataProtectorTokenizer = dataProtectorTokenizer; _dataProtectorTokenizer = dataProtectorTokenizer;
_removeOrganizationUserCommand = removeOrganizationUserCommand; _removeOrganizationUserCommand = removeOrganizationUserCommand;
} }
@ -126,7 +118,12 @@ public class EmergencyAccessService : IEmergencyAccessService
throw new BadRequestException("Emergency Access not valid."); throw new BadRequestException("Emergency Access not valid.");
} }
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email)) if (!_dataProtectorTokenizer.TryUnprotect(token, out var data))
{
throw new BadRequestException("Invalid token.");
}
if (!data.IsValid(emergencyAccessId, user.Email))
{ {
throw new BadRequestException("Invalid token."); throw new BadRequestException("Invalid token.");
} }
@ -140,6 +137,8 @@ public class EmergencyAccessService : IEmergencyAccessService
throw new BadRequestException("Invitation already accepted."); throw new BadRequestException("Invitation already accepted.");
} }
// TODO PM-21687
// Might not be reachable since the Tokenable.IsValid() does an email comparison
if (string.IsNullOrWhiteSpace(emergencyAccess.Email) || if (string.IsNullOrWhiteSpace(emergencyAccess.Email) ||
!emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) !emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{ {
@ -163,6 +162,8 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId) public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId)
{ {
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
// TODO PM-19438/PM-21687
// Not sure why the GrantorId and the GranteeId are supposed to be the same?
if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId)) if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId))
{ {
throw new BadRequestException("Emergency Access not valid."); throw new BadRequestException("Emergency Access not valid.");
@ -171,9 +172,9 @@ public class EmergencyAccessService : IEmergencyAccessService
await _emergencyAccessRepository.DeleteAsync(emergencyAccess); await _emergencyAccessRepository.DeleteAsync(emergencyAccess);
} }
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId) public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAccessId, string key, Guid confirmingUserId)
{ {
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId); var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted || if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted ||
emergencyAccess.GrantorId != confirmingUserId) emergencyAccess.GrantorId != confirmingUserId)
{ {
@ -224,7 +225,6 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task InitiateAsync(Guid id, User initiatingUser) public async Task InitiateAsync(Guid id, User initiatingUser)
{ {
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id || if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.Confirmed) emergencyAccess.Status != EmergencyAccessStatusType.Confirmed)
{ {
@ -285,6 +285,9 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser) public async Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser)
{ {
// TODO PM-21687
// Should we look up policies here or just verify the EmergencyAccess is correct
// and handle policy logic else where? Should this be a query/Command?
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
@ -295,7 +298,9 @@ public class EmergencyAccessService : IEmergencyAccessService
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id);
var isOrganizationOwner = grantorOrganizations.Any<OrganizationUser>(organization => organization.Type == OrganizationUserType.Owner); var isOrganizationOwner = grantorOrganizations
.Any(organization => organization.Type == OrganizationUserType.Owner);
var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null; var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null;
return policies; return policies;
@ -311,7 +316,8 @@ public class EmergencyAccessService : IEmergencyAccessService
} }
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
// TODO PM-21687
// Redundant check of the EmergencyAccessType -> checked in IsValidRequest() ln 308
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{ {
throw new BadRequestException("You cannot takeover an account that is using Key Connector."); throw new BadRequestException("You cannot takeover an account that is using Key Connector.");
@ -336,7 +342,9 @@ public class EmergencyAccessService : IEmergencyAccessService
grantor.LastPasswordChangeDate = grantor.RevisionDate; grantor.LastPasswordChangeDate = grantor.RevisionDate;
grantor.Key = key; grantor.Key = key;
// Disable TwoFactor providers since they will otherwise block logins // Disable TwoFactor providers since they will otherwise block logins
grantor.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>()); grantor.SetTwoFactorProviders([]);
// Disable New Device Verification since it will otherwise block logins
grantor.VerifyDevices = false;
await _userRepository.ReplaceAsync(grantor); await _userRepository.ReplaceAsync(grantor);
// Remove grantor from all organizations unless Owner // Remove grantor from all organizations unless Owner
@ -421,12 +429,22 @@ public class EmergencyAccessService : IEmergencyAccessService
await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token); await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token);
} }
private string NameOrEmail(User user) private static string NameOrEmail(User user)
{ {
return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name; return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name;
} }
private bool IsValidRequest(EmergencyAccess availableAccess, User requestingUser, EmergencyAccessType requestedAccessType)
/*
* Checks if EmergencyAccess Object is null
* Checks the requesting user is the same as the granteeUser (So we are checking for proper grantee action)
* Status _must_ equal RecoveryApproved (This means the grantor has invited, the grantee has accepted, and the grantor has approved so the shared key exists but hasn't been exercised yet)
* request type must equal the type of access requested (View or Takeover)
*/
private static bool IsValidRequest(
EmergencyAccess availableAccess,
User requestingUser,
EmergencyAccessType requestedAccessType)
{ {
return availableAccess != null && return availableAccess != null &&
availableAccess.GranteeId == requestingUser.Id && availableAccess.GranteeId == requestingUser.Id &&

View File

@ -108,6 +108,7 @@ public class RegisterUserCommand : IRegisterUserCommand
var result = await _userService.CreateUserAsync(user, masterPasswordHash); var result = await _userService.CreateUserAsync(user, masterPasswordHash);
if (result == IdentityResult.Success) if (result == IdentityResult.Success)
{ {
var sentWelcomeEmail = false;
if (!string.IsNullOrEmpty(user.ReferenceData)) if (!string.IsNullOrEmpty(user.ReferenceData))
{ {
var referenceData = JsonConvert.DeserializeObject<Dictionary<string, object>>(user.ReferenceData); var referenceData = JsonConvert.DeserializeObject<Dictionary<string, object>>(user.ReferenceData);
@ -115,6 +116,7 @@ public class RegisterUserCommand : IRegisterUserCommand
{ {
var initiationPath = value.ToString(); var initiationPath = value.ToString();
await SendAppropriateWelcomeEmailAsync(user, initiationPath); await SendAppropriateWelcomeEmailAsync(user, initiationPath);
sentWelcomeEmail = true;
if (!string.IsNullOrEmpty(initiationPath)) if (!string.IsNullOrEmpty(initiationPath))
{ {
await _referenceEventService.RaiseEventAsync( await _referenceEventService.RaiseEventAsync(
@ -128,6 +130,11 @@ public class RegisterUserCommand : IRegisterUserCommand
} }
} }
if (!sentWelcomeEmail)
{
await _mailService.SendWelcomeEmailAsync(user);
}
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user, _currentContext)); await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user, _currentContext));
} }

View File

@ -2,10 +2,6 @@
public static class StripeConstants public static class StripeConstants
{ {
public static class Prices
{
public const string StoragePlanPersonal = "personal-storage-gb-annually";
}
public static class AutomaticTaxStatus public static class AutomaticTaxStatus
{ {
public const string Failed = "failed"; public const string Failed = "failed";
@ -69,6 +65,11 @@ public static class StripeConstants
public const string USBankAccount = "us_bank_account"; public const string USBankAccount = "us_bank_account";
} }
public static class Prices
{
public const string StoragePlanPersonal = "personal-storage-gb-annually";
}
public static class ProrationBehavior public static class ProrationBehavior
{ {
public const string AlwaysInvoice = "always_invoice"; public const string AlwaysInvoice = "always_invoice";
@ -88,6 +89,13 @@ public static class StripeConstants
public const string Paused = "paused"; public const string Paused = "paused";
} }
public static class TaxExempt
{
public const string Exempt = "exempt";
public const string None = "none";
public const string Reverse = "reverse";
}
public static class ValidateTaxLocationTiming public static class ValidateTaxLocationTiming
{ {
public const string Deferred = "deferred"; public const string Deferred = "deferred";

View File

@ -15,12 +15,7 @@ public static class CustomerExtensions
} }
}; };
/// <summary> public static bool HasRecognizedTaxLocation(this Customer customer) =>
/// Determines if a Stripe customer supports automatic tax
/// </summary>
/// <param name="customer"></param>
/// <returns></returns>
public static bool HasTaxLocationVerified(this Customer customer) =>
customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation; customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation;
public static decimal GetBillingBalance(this Customer customer) public static decimal GetBillingBalance(this Customer customer)

View File

@ -4,7 +4,9 @@ using Bit.Core.Billing.Licenses.Extensions;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Billing.Tax.Commands;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations;
namespace Bit.Core.Billing.Extensions; namespace Bit.Core.Billing.Extensions;
@ -24,5 +26,6 @@ public static class ServiceCollectionExtensions
services.AddTransient<IAutomaticTaxFactory, AutomaticTaxFactory>(); services.AddTransient<IAutomaticTaxFactory, AutomaticTaxFactory>();
services.AddLicenseServices(); services.AddLicenseServices();
services.AddPricingClient(); services.AddPricingClient();
services.AddTransient<IPreviewTaxAmountCommand, PreviewTaxAmountCommand>();
} }
} }

View File

@ -22,7 +22,7 @@ public static class SubscriptionUpdateOptionsExtensions
} }
// We might only need to check the automatic tax status. // We might only need to check the automatic tax status.
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{ {
return false; return false;
} }

View File

@ -22,7 +22,7 @@ public static class UpcomingInvoiceOptionsExtensions
} }
// We might only need to check the automatic tax status. // We might only need to check the automatic tax status.
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{ {
return false; return false;
} }

View File

@ -34,7 +34,6 @@ public static class OrganizationLicenseConstants
public const string UseSecretsManager = nameof(UseSecretsManager); public const string UseSecretsManager = nameof(UseSecretsManager);
public const string SmSeats = nameof(SmSeats); public const string SmSeats = nameof(SmSeats);
public const string SmServiceAccounts = nameof(SmServiceAccounts); public const string SmServiceAccounts = nameof(SmServiceAccounts);
public const string SmMaxProjects = nameof(SmMaxProjects);
public const string LimitCollectionCreationDeletion = nameof(LimitCollectionCreationDeletion); public const string LimitCollectionCreationDeletion = nameof(LimitCollectionCreationDeletion);
public const string AllowAdminAccessToAllCollectionItems = nameof(AllowAdminAccessToAllCollectionItems); public const string AllowAdminAccessToAllCollectionItems = nameof(AllowAdminAccessToAllCollectionItems);
public const string UseRiskInsights = nameof(UseRiskInsights); public const string UseRiskInsights = nameof(UseRiskInsights);
@ -43,6 +42,7 @@ public static class OrganizationLicenseConstants
public const string ExpirationWithoutGracePeriod = nameof(ExpirationWithoutGracePeriod); public const string ExpirationWithoutGracePeriod = nameof(ExpirationWithoutGracePeriod);
public const string Trial = nameof(Trial); public const string Trial = nameof(Trial);
public const string UseAdminSponsoredFamilies = nameof(UseAdminSponsoredFamilies); public const string UseAdminSponsoredFamilies = nameof(UseAdminSponsoredFamilies);
public const string UseOrganizationDomains = nameof(UseOrganizationDomains);
} }
public static class UserLicenseConstants public static class UserLicenseConstants

View File

@ -7,5 +7,4 @@ public class LicenseContext
{ {
public Guid? InstallationId { get; init; } public Guid? InstallationId { get; init; }
public required SubscriptionInfo SubscriptionInfo { get; init; } public required SubscriptionInfo SubscriptionInfo { get; init; }
public int? SmMaxProjects { get; set; }
} }

View File

@ -54,6 +54,7 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory<Organizati
new(nameof(OrganizationLicenseConstants.ExpirationWithoutGracePeriod), expirationWithoutGracePeriod.ToString(CultureInfo.InvariantCulture)), new(nameof(OrganizationLicenseConstants.ExpirationWithoutGracePeriod), expirationWithoutGracePeriod.ToString(CultureInfo.InvariantCulture)),
new(nameof(OrganizationLicenseConstants.Trial), trial.ToString()), new(nameof(OrganizationLicenseConstants.Trial), trial.ToString()),
new(nameof(OrganizationLicenseConstants.UseAdminSponsoredFamilies), entity.UseAdminSponsoredFamilies.ToString()), new(nameof(OrganizationLicenseConstants.UseAdminSponsoredFamilies), entity.UseAdminSponsoredFamilies.ToString()),
new(nameof(OrganizationLicenseConstants.UseOrganizationDomains), entity.UseOrganizationDomains.ToString()),
}; };
if (entity.Name is not null) if (entity.Name is not null)
@ -112,11 +113,6 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory<Organizati
} }
claims.Add(new Claim(nameof(OrganizationLicenseConstants.UseAdminSponsoredFamilies), entity.UseAdminSponsoredFamilies.ToString())); claims.Add(new Claim(nameof(OrganizationLicenseConstants.UseAdminSponsoredFamilies), entity.UseAdminSponsoredFamilies.ToString()));
if (licenseContext.SmMaxProjects.HasValue)
{
claims.Add(new Claim(nameof(OrganizationLicenseConstants.SmMaxProjects), licenseContext.SmMaxProjects.ToString()));
}
return Task.FromResult(claims); return Task.FromResult(claims);
} }

View File

@ -309,6 +309,7 @@ public class OrganizationMigrator(
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.UsePolicies = plan.HasPolicies; organization.UsePolicies = plan.HasPolicies;
organization.UseSso = plan.HasSso; organization.UseSso = plan.HasSso;
organization.UseOrganizationDomains = plan.HasOrganizationDomains;
organization.UseGroups = plan.HasGroups; organization.UseGroups = plan.HasGroups;
organization.UseEvents = plan.HasEvents; organization.UseEvents = plan.HasEvents;
organization.UseDirectory = plan.HasDirectory; organization.UseDirectory = plan.HasDirectory;

View File

@ -0,0 +1,36 @@
using OneOf;
namespace Bit.Core.Billing.Models;
public record BadRequest(string TranslationKey)
{
public static BadRequest TaxIdNumberInvalid => new(BillingErrorTranslationKeys.TaxIdInvalid);
public static BadRequest TaxLocationInvalid => new(BillingErrorTranslationKeys.CustomerTaxLocationInvalid);
public static BadRequest UnknownTaxIdType => new(BillingErrorTranslationKeys.UnknownTaxIdType);
}
public record Unhandled(string TranslationKey = BillingErrorTranslationKeys.UnhandledError);
public class BillingCommandResult<T> : OneOfBase<T, BadRequest, Unhandled>
{
private BillingCommandResult(OneOf<T, BadRequest, Unhandled> input) : base(input) { }
public static implicit operator BillingCommandResult<T>(T output) => new(output);
public static implicit operator BillingCommandResult<T>(BadRequest badRequest) => new(badRequest);
public static implicit operator BillingCommandResult<T>(Unhandled unhandled) => new(unhandled);
}
public static class BillingErrorTranslationKeys
{
// "The tax ID number you provided was invalid. Please try again or contact support."
public const string TaxIdInvalid = "taxIdInvalid";
// "Your location wasn't recognized. Please ensure your country and postal code are valid and try again."
public const string CustomerTaxLocationInvalid = "customerTaxLocationInvalid";
// "Something went wrong with your request. Please contact support."
public const string UnhandledError = "unhandledBillingError";
// "We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support."
public const string UnknownTaxIdType = "unknownTaxIdType";
}

View File

@ -1,4 +1,6 @@
namespace Bit.Core.Billing.Models; using Bit.Core.Billing.Tax.Models;
namespace Bit.Core.Billing.Models;
public record PaymentMethod( public record PaymentMethod(
long AccountCredit, long AccountCredit,

View File

@ -1,4 +1,6 @@
namespace Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Tax.Models;
namespace Bit.Core.Billing.Models.Sales;
#nullable enable #nullable enable

View File

@ -1,5 +1,6 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
namespace Bit.Core.Billing.Models.Sales; namespace Bit.Core.Billing.Models.Sales;
@ -26,12 +27,21 @@ public class OrganizationSale
public static OrganizationSale From( public static OrganizationSale From(
Organization organization, Organization organization,
OrganizationSignup signup) => new() OrganizationSignup signup)
{
var customerSetup = string.IsNullOrEmpty(organization.GatewayCustomerId) ? GetCustomerSetup(signup) : null;
var subscriptionSetup = GetSubscriptionSetup(signup);
subscriptionSetup.SkipTrial = signup.SkipTrial;
return new OrganizationSale
{ {
Organization = organization, Organization = organization,
CustomerSetup = string.IsNullOrEmpty(organization.GatewayCustomerId) ? GetCustomerSetup(signup) : null, CustomerSetup = customerSetup,
SubscriptionSetup = GetSubscriptionSetup(signup) SubscriptionSetup = subscriptionSetup
}; };
}
public static OrganizationSale From( public static OrganizationSale From(
Organization organization, Organization organization,

View File

@ -1,4 +1,5 @@
using Bit.Core.Entities; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;

View File

@ -24,6 +24,7 @@ public abstract record Plan
public bool Has2fa { get; protected init; } public bool Has2fa { get; protected init; }
public bool HasApi { get; protected init; } public bool HasApi { get; protected init; }
public bool HasSso { get; protected init; } public bool HasSso { get; protected init; }
public bool HasOrganizationDomains { get; protected init; }
public bool HasKeyConnector { get; protected init; } public bool HasKeyConnector { get; protected init; }
public bool HasScim { get; protected init; } public bool HasScim { get; protected init; }
public bool HasResetPassword { get; protected init; } public bool HasResetPassword { get; protected init; }

View File

@ -26,6 +26,7 @@ public record Enterprise2019Plan : Plan
Has2fa = true; Has2fa = true;
HasApi = true; HasApi = true;
HasSso = true; HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true; HasKeyConnector = true;
HasScim = true; HasScim = true;
HasResetPassword = true; HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record Enterprise2020Plan : Plan
Has2fa = true; Has2fa = true;
HasApi = true; HasApi = true;
HasSso = true; HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true; HasKeyConnector = true;
HasScim = true; HasScim = true;
HasResetPassword = true; HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record EnterprisePlan : Plan
Has2fa = true; Has2fa = true;
HasApi = true; HasApi = true;
HasSso = true; HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true; HasKeyConnector = true;
HasScim = true; HasScim = true;
HasResetPassword = true; HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record Enterprise2023Plan : Plan
Has2fa = true; Has2fa = true;
HasApi = true; HasApi = true;
HasSso = true; HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true; HasKeyConnector = true;
HasScim = true; HasScim = true;
HasResetPassword = true; HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record PlanAdapter : Plan
Has2fa = HasFeature("2fa"); Has2fa = HasFeature("2fa");
HasApi = HasFeature("api"); HasApi = HasFeature("api");
HasSso = HasFeature("sso"); HasSso = HasFeature("sso");
HasOrganizationDomains = HasFeature("organizationDomains");
HasKeyConnector = HasFeature("keyConnector"); HasKeyConnector = HasFeature("keyConnector");
HasScim = HasFeature("scim"); HasScim = HasFeature("scim");
HasResetPassword = HasFeature("resetPassword"); HasResetPassword = HasFeature("resetPassword");

View File

@ -1,6 +1,7 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Tax.Models;
namespace Bit.Core.Billing.Services; namespace Bit.Core.Billing.Services;

View File

@ -1,5 +1,6 @@
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities; using Bit.Core.Entities;
namespace Bit.Core.Billing.Services; namespace Bit.Core.Billing.Services;

View File

@ -4,6 +4,7 @@ using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
using Stripe; using Stripe;

View File

@ -1,4 +1,5 @@
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Stripe; using Stripe;

View File

@ -1,11 +1,13 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories; using Bit.Core.Repositories;
@ -33,16 +35,15 @@ public class OrganizationBillingService(
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService,
ITaxService taxService, ITaxService taxService) : IOrganizationBillingService
IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService
{ {
public async Task Finalize(OrganizationSale sale) public async Task Finalize(OrganizationSale sale)
{ {
var (organization, customerSetup, subscriptionSetup) = sale; var (organization, customerSetup, subscriptionSetup) = sale;
var customer = string.IsNullOrEmpty(organization.GatewayCustomerId) && customerSetup != null var customer = string.IsNullOrEmpty(organization.GatewayCustomerId) && customerSetup != null
? await CreateCustomerAsync(organization, customerSetup) ? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType)
: await subscriberService.GetCustomerOrThrow(organization, new CustomerGetOptions { Expand = ["tax", "tax_ids"] }); : await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup);
var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup); var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup);
@ -119,7 +120,8 @@ public class OrganizationBillingService(
subscription.CurrentPeriodEnd); subscription.CurrentPeriodEnd);
} }
public async Task UpdatePaymentMethod( public async Task
UpdatePaymentMethod(
Organization organization, Organization organization,
TokenizedPaymentSource tokenizedPaymentSource, TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation) TaxInformation taxInformation)
@ -149,8 +151,11 @@ public class OrganizationBillingService(
private async Task<Customer> CreateCustomerAsync( private async Task<Customer> CreateCustomerAsync(
Organization organization, Organization organization,
CustomerSetup customerSetup) CustomerSetup customerSetup,
PlanType? updatedPlanType = null)
{ {
var planType = updatedPlanType ?? organization.PlanType;
var displayName = organization.DisplayName(); var displayName = organization.DisplayName();
var customerCreateOptions = new CustomerCreateOptions var customerCreateOptions = new CustomerCreateOptions
@ -210,13 +215,24 @@ public class OrganizationBillingService(
City = customerSetup.TaxInformation.City, City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode, PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State, State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country, Country = customerSetup.TaxInformation.Country
}; };
customerCreateOptions.Tax = new CustomerTaxOptions customerCreateOptions.Tax = new CustomerTaxOptions
{ {
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
}; };
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge &&
planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families &&
customerSetup.TaxInformation.Country != "US")
{
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId)) if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId))
{ {
var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country, var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country,
@ -397,21 +413,68 @@ public class OrganizationBillingService(
TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays
}; };
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{ {
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType); subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
} }
else else if (customer.HasRecognizedTaxLocation())
{ {
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions(); subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation(); {
Enabled =
subscriptionSetup.PlanType.GetProductTier() == ProductTierType.Families ||
customer.Address.Country == "US" ||
customer.TaxIds.Any()
};
} }
return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
} }
private async Task<Customer> GetCustomerWhileEnsuringCorrectTaxExemptionAsync(
Organization organization,
SubscriptionSetup subscriptionSetup)
{
var customer = await subscriberService.GetCustomerOrThrow(organization,
new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (!setNonUSBusinessUseToReverseCharge || subscriptionSetup.PlanType.GetProductTier() is
not (ProductTierType.Teams or
ProductTierType.TeamsStarter or
ProductTierType.Enterprise))
{
return customer;
}
List<string> expansions = ["tax", "tax_ids"];
customer = customer switch
{
{ Address.Country: not "US", TaxExempt: not StripeConstants.TaxExempt.Reverse } => await
stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions
{
Expand = expansions,
TaxExempt = StripeConstants.TaxExempt.Reverse
}),
{ Address.Country: "US", TaxExempt: StripeConstants.TaxExempt.Reverse } => await
stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions
{
Expand = expansions,
TaxExempt = StripeConstants.TaxExempt.None
}),
_ => customer
};
return customer;
}
private async Task<bool> IsEligibleForSelfHostAsync( private async Task<bool> IsEligibleForSelfHostAsync(
Organization organization) Organization organization)
{ {

View File

@ -2,7 +2,7 @@
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Services.Implementations.AutomaticTax; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
@ -10,7 +10,6 @@ using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Braintree; using Braintree;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Stripe; using Stripe;
using Customer = Stripe.Customer; using Customer = Stripe.Customer;
@ -22,20 +21,18 @@ using static Utilities;
public class PremiumUserBillingService( public class PremiumUserBillingService(
IBraintreeGateway braintreeGateway, IBraintreeGateway braintreeGateway,
IFeatureService featureService,
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
ILogger<PremiumUserBillingService> logger, ILogger<PremiumUserBillingService> logger,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService,
IUserRepository userRepository, IUserRepository userRepository) : IPremiumUserBillingService
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService
{ {
public async Task Credit(User user, decimal amount) public async Task Credit(User user, decimal amount)
{ {
var customer = await subscriberService.GetCustomer(user); var customer = await subscriberService.GetCustomer(user);
// Negative credit represents a balance and all Stripe denomination is in cents. // Negative credit represents a balance, and all Stripe denomination is in cents.
var credit = (long)(amount * -100); var credit = (long)(amount * -100);
if (customer == null) if (customer == null)
@ -182,7 +179,7 @@ public class PremiumUserBillingService(
City = customerSetup.TaxInformation.City, City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode, PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State, State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country, Country = customerSetup.TaxInformation.Country
}, },
Description = user.Name, Description = user.Name,
Email = user.Email, Email = user.Email,
@ -322,6 +319,10 @@ public class PremiumUserBillingService(
var subscriptionCreateOptions = new SubscriptionCreateOptions var subscriptionCreateOptions = new SubscriptionCreateOptions
{ {
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id, Customer = customer.Id,
Items = subscriptionItemOptionsList, Items = subscriptionItemOptionsList,
@ -335,18 +336,6 @@ public class PremiumUserBillingService(
OffSession = true OffSession = true
}; };
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported,
};
}
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
if (usingPayPal) if (usingPayPal)
@ -378,7 +367,7 @@ public class PremiumUserBillingService(
City = taxInformation.City, City = taxInformation.City,
PostalCode = taxInformation.PostalCode, PostalCode = taxInformation.PostalCode,
State = taxInformation.State, State = taxInformation.State,
Country = taxInformation.Country, Country = taxInformation.Country
}, },
Expand = ["tax"], Expand = ["tax"],
Tax = new CustomerTaxOptions Tax = new CustomerTaxOptions

View File

@ -1,7 +1,12 @@
using Bit.Core.Billing.Caches; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
@ -26,8 +31,7 @@ public class SubscriberService(
ILogger<SubscriberService> logger, ILogger<SubscriberService> logger,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ITaxService taxService, ITaxService taxService) : ISubscriberService
IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService
{ {
public async Task CancelSubscription( public async Task CancelSubscription(
ISubscriber subscriber, ISubscriber subscriber,
@ -126,7 +130,7 @@ public class SubscriberService(
[subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion [subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion
}, },
Email = subscriber.BillingEmailAddress(), Email = subscriber.BillingEmailAddress(),
PaymentMethodNonce = paymentMethodNonce, PaymentMethodNonce = paymentMethodNonce
}); });
if (customerResult.IsSuccess()) if (customerResult.IsSuccess())
@ -480,7 +484,7 @@ public class SubscriberService(
var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First(); var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First();
// Find the customer's existing setup intents that should be cancelled. // Find the customer's existing setup intents that should be canceled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer) var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si => .Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action"); si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
@ -517,7 +521,7 @@ public class SubscriberService(
await stripeAdapter.PaymentMethodAttachAsync(token, await stripeAdapter.PaymentMethodAttachAsync(token,
new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId });
// Find the customer's existing setup intents that should be cancelled. // Find the customer's existing setup intents that should be canceled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer) var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si => .Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action"); si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
@ -635,7 +639,8 @@ public class SubscriberService(
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInformation.Country, taxInformation.Country,
taxInformation.TaxId); taxInformation.TaxId);
throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError");
throw new BadRequestException("billingTaxIdTypeInferenceError");
} }
} }
@ -652,53 +657,84 @@ public class SubscriberService(
logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.",
taxInformation.TaxId, taxInformation.TaxId,
taxInformation.Country); taxInformation.Country);
throw new Exceptions.BadRequestException("billingInvalidTaxIdError");
throw new BadRequestException("billingInvalidTaxIdError");
default: default:
logger.LogError(e, logger.LogError(e,
"Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.",
taxInformation.TaxId, taxInformation.TaxId,
taxInformation.Country, taxInformation.Country,
customer.Id); customer.Id);
throw new Exceptions.BadRequestException("billingTaxIdCreationError");
throw new BadRequestException("billingTaxIdCreationError");
} }
} }
} }
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var subscription =
customer.Subscriptions.First(subscription => subscription.Id == subscriber.GatewaySubscriptionId);
var isBusinessUseSubscriber = subscriber switch
{ {
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) Organization organization => organization.PlanType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families,
Provider => true,
_ => false
};
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && isBusinessUseSubscriber)
{
switch (customer)
{ {
var subscriptionGetOptions = new SubscriptionGetOptions case
{ {
Expand = ["customer.tax", "customer.tax_ids"] Address.Country: not "US",
}; TaxExempt: not StripeConstants.TaxExempt.Reverse
var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); }:
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); await stripeAdapter.CustomerUpdateAsync(customer.Id,
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription); break;
if (automaticTaxOptions?.AutomaticTax?.Enabled != null) case
{ {
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions); Address.Country: "US",
} TaxExempt: StripeConstants.TaxExempt.Reverse
}:
await stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.None });
break;
} }
}
else if (!subscription.AutomaticTax.Enabled)
{
if (SubscriberIsEligibleForAutomaticTax(subscriber, customer))
{ {
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
new SubscriptionUpdateOptions new SubscriptionUpdateOptions
{ {
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
}); });
} }
}
else
{
var automaticTaxShouldBeEnabled = subscriber switch
{
User => true,
Organization organization => organization.PlanType.GetProductTier() == ProductTierType.Families ||
customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
Provider => customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
_ => false
};
return; if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled)
{
bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
=> !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && new SubscriptionUpdateOptions
(localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && {
localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
} }
} }

View File

@ -0,0 +1,147 @@
#nullable enable
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Services;
using Microsoft.Extensions.Logging;
using Stripe;
namespace Bit.Core.Billing.Tax.Commands;
public interface IPreviewTaxAmountCommand
{
Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters);
}
public class PreviewTaxAmountCommand(
ILogger<PreviewTaxAmountCommand> logger,
IPricingClient pricingClient,
IStripeAdapter stripeAdapter,
ITaxService taxService) : IPreviewTaxAmountCommand
{
public async Task<BillingCommandResult<decimal>> Run(OrganizationTrialParameters parameters)
{
var (planType, productType, taxInformation) = parameters;
var plan = await pricingClient.GetPlanOrThrow(planType);
var options = new InvoiceCreatePreviewOptions
{
Currency = "usd",
CustomerDetails = new InvoiceCustomerDetailsOptions
{
Address = new AddressOptions
{
Country = taxInformation.Country,
PostalCode = taxInformation.PostalCode
}
},
SubscriptionDetails = new InvoiceSubscriptionDetailsOptions
{
Items = [
new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.HasNonSeatBasedPasswordManagerPlan() ? plan.PasswordManager.StripePlanId : plan.PasswordManager.StripeSeatPlanId,
Quantity = 1
}
]
}
};
if (productType == ProductType.SecretsManager)
{
options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions
{
Price = plan.SecretsManager.StripeSeatPlanId,
Quantity = 1
});
options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone;
}
if (!string.IsNullOrEmpty(taxInformation.TaxId))
{
var taxIdType = taxService.GetStripeTaxCode(
taxInformation.Country,
taxInformation.TaxId);
if (string.IsNullOrEmpty(taxIdType))
{
return BadRequest.UnknownTaxIdType;
}
options.CustomerDetails.TaxIds = [
new InvoiceCustomerDetailsTaxIdOptions
{
Type = taxIdType,
Value = taxInformation.TaxId
}
];
}
if (planType.GetProductTier() == ProductTierType.Families)
{
options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true };
}
else
{
options.AutomaticTax = new InvoiceAutomaticTaxOptions
{
Enabled = options.CustomerDetails.Address.Country == "US" ||
options.CustomerDetails.TaxIds is [_, ..]
};
}
try
{
var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options);
return Convert.ToDecimal(invoice.Tax) / 100;
}
catch (StripeException stripeException) when (stripeException.StripeError.Code ==
StripeConstants.ErrorCodes.CustomerTaxLocationInvalid)
{
return BadRequest.TaxLocationInvalid;
}
catch (StripeException stripeException) when (stripeException.StripeError.Code ==
StripeConstants.ErrorCodes.TaxIdInvalid)
{
return BadRequest.TaxIdNumberInvalid;
}
catch (StripeException stripeException)
{
logger.LogError(stripeException, "Stripe responded with an error during {Operation}. Code: {Code}", nameof(PreviewTaxAmountCommand), stripeException.StripeError.Code);
return new Unhandled();
}
}
}
#region Command Parameters
public record OrganizationTrialParameters
{
public required PlanType PlanType { get; set; }
public required ProductType ProductType { get; set; }
public required TaxInformationDTO TaxInformation { get; set; }
public void Deconstruct(
out PlanType planType,
out ProductType productType,
out TaxInformationDTO taxInformation)
{
planType = PlanType;
productType = ProductType;
taxInformation = TaxInformation;
}
public record TaxInformationDTO
{
public required string Country { get; set; }
public required string PostalCode { get; set; }
public string? TaxId { get; set; }
}
}
#endregion

View File

@ -1,6 +1,6 @@
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
namespace Bit.Core.Billing.Models; namespace Bit.Core.Billing.Tax.Models;
public class TaxIdType public class TaxIdType
{ {

View File

@ -1,6 +1,6 @@
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
namespace Bit.Core.Billing.Models; namespace Bit.Core.Billing.Tax.Models;
public record TaxInformation( public record TaxInformation(
string Country, string Country,

View File

@ -1,6 +1,6 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Core.Billing.Models.Api.Requests.Accounts; namespace Bit.Core.Billing.Tax.Requests;
public class PreviewIndividualInvoiceRequestBody public class PreviewIndividualInvoiceRequestBody
{ {

View File

@ -2,7 +2,7 @@
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Enums; using Bit.Core.Enums;
namespace Bit.Core.Billing.Models.Api.Requests.Organizations; namespace Bit.Core.Billing.Tax.Requests;
public class PreviewOrganizationInvoiceRequestBody public class PreviewOrganizationInvoiceRequestBody
{ {

View File

@ -1,6 +1,6 @@
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Core.Billing.Models.Api.Requests; namespace Bit.Core.Billing.Tax.Requests;
public class TaxInformationRequestModel public class TaxInformationRequestModel
{ {

View File

@ -1,4 +1,4 @@
namespace Bit.Core.Billing.Models.Api.Responses; namespace Bit.Core.Billing.Tax.Responses;
public record PreviewInvoiceResponseModel( public record PreviewInvoiceResponseModel(
decimal EffectiveTaxRate, decimal EffectiveTaxRate,

View File

@ -1,6 +1,6 @@
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Contracts;
namespace Bit.Core.Billing.Services; namespace Bit.Core.Billing.Tax.Services;
/// <summary> /// <summary>
/// Responsible for defining the correct automatic tax strategy for either personal use of business use. /// Responsible for defining the correct automatic tax strategy for either personal use of business use.

View File

@ -1,7 +1,7 @@
#nullable enable #nullable enable
using Stripe; using Stripe;
namespace Bit.Core.Billing.Services; namespace Bit.Core.Billing.Tax.Services;
public interface IAutomaticTaxStrategy public interface IAutomaticTaxStrategy
{ {

View File

@ -1,4 +1,4 @@
namespace Bit.Core.Billing.Services; namespace Bit.Core.Billing.Tax.Services;
public interface ITaxService public interface ITaxService
{ {

View File

@ -5,7 +5,7 @@ using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Services; using Bit.Core.Services;
namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; namespace Bit.Core.Billing.Tax.Services.Implementations;
public class AutomaticTaxFactory( public class AutomaticTaxFactory(
IFeatureService featureService, IFeatureService featureService,

View File

@ -3,7 +3,7 @@ using Bit.Core.Billing.Extensions;
using Bit.Core.Services; using Bit.Core.Services;
using Stripe; using Stripe;
namespace Bit.Core.Billing.Services.Implementations.AutomaticTax; namespace Bit.Core.Billing.Tax.Services.Implementations;
public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : IAutomaticTaxStrategy
{ {
@ -76,7 +76,7 @@ public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : I
private bool ShouldBeEnabled(Customer customer) private bool ShouldBeEnabled(Customer customer)
{ {
if (!customer.HasTaxLocationVerified()) if (!customer.HasRecognizedTaxLocation())
{ {
return false; return false;
} }

Some files were not shown because too many files have changed in this diff Show More