1
0
mirror of https://github.com/bitwarden/server.git synced 2025-05-23 04:21:05 -05:00

Merge branch 'main' into ac/pm-21612/unified-mariadb-cannot-edit-an-unconfirmed-user

This commit is contained in:
Thomas Rittson 2025-05-21 11:42:24 +10:00 committed by GitHub
commit 4e727f3660
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
100 changed files with 5243 additions and 2365 deletions

View File

@ -636,7 +636,9 @@ jobs:
setup-ephemeral-environment: setup-ephemeral-environment:
name: Setup Ephemeral Environment name: Setup Ephemeral Environment
needs: build-docker needs:
- build-artifacts
- build-docker
if: | if: |
needs.build-artifacts.outputs.has_secrets == 'true' needs.build-artifacts.outputs.has_secrets == 'true'
&& github.event_name == 'pull_request' && github.event_name == 'pull_request'

View File

@ -3,7 +3,7 @@
<PropertyGroup> <PropertyGroup>
<TargetFramework>net8.0</TargetFramework> <TargetFramework>net8.0</TargetFramework>
<Version>2025.5.0</Version> <Version>2025.5.1</Version>
<RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace> <RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace>
<ImplicitUsings>enable</ImplicitUsings> <ImplicitUsings>enable</ImplicitUsings>

View File

@ -8,13 +8,10 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Microsoft.Extensions.DependencyInjection;
using Stripe; using Stripe;
namespace Bit.Commercial.Core.AdminConsole.Providers; namespace Bit.Commercial.Core.AdminConsole.Providers;
@ -24,7 +21,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
private readonly IEventService _eventService; private readonly IEventService _eventService;
private readonly IMailService _mailService; private readonly IMailService _mailService;
private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationService _organizationService;
private readonly IProviderOrganizationRepository _providerOrganizationRepository; private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IStripeAdapter _stripeAdapter; private readonly IStripeAdapter _stripeAdapter;
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
@ -32,26 +28,22 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
private readonly ISubscriberService _subscriberService; private readonly ISubscriberService _subscriberService;
private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly IAutomaticTaxStrategy _automaticTaxStrategy;
public RemoveOrganizationFromProviderCommand( public RemoveOrganizationFromProviderCommand(
IEventService eventService, IEventService eventService,
IMailService mailService, IMailService mailService,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
IOrganizationService organizationService,
IProviderOrganizationRepository providerOrganizationRepository, IProviderOrganizationRepository providerOrganizationRepository,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
IFeatureService featureService, IFeatureService featureService,
IProviderBillingService providerBillingService, IProviderBillingService providerBillingService,
ISubscriberService subscriberService, ISubscriberService subscriberService,
IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery,
IPricingClient pricingClient, IPricingClient pricingClient)
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
{ {
_eventService = eventService; _eventService = eventService;
_mailService = mailService; _mailService = mailService;
_organizationRepository = organizationRepository; _organizationRepository = organizationRepository;
_organizationService = organizationService;
_providerOrganizationRepository = providerOrganizationRepository; _providerOrganizationRepository = providerOrganizationRepository;
_stripeAdapter = stripeAdapter; _stripeAdapter = stripeAdapter;
_featureService = featureService; _featureService = featureService;
@ -59,7 +51,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
_subscriberService = subscriberService; _subscriberService = subscriberService;
_hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery;
_pricingClient = pricingClient; _pricingClient = pricingClient;
_automaticTaxStrategy = automaticTaxStrategy;
} }
public async Task RemoveOrganizationFromProvider( public async Task RemoveOrganizationFromProvider(
@ -77,7 +68,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync( if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync(
providerOrganization.OrganizationId, providerOrganization.OrganizationId,
Array.Empty<Guid>(), [],
includeProvider: false)) includeProvider: false))
{ {
throw new BadRequestException("Organization must have at least one confirmed owner."); throw new BadRequestException("Organization must have at least one confirmed owner.");
@ -102,7 +93,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
/// <summary> /// <summary>
/// When a client organization is unlinked from a provider, we have to check if they're Stripe-enabled /// When a client organization is unlinked from a provider, we have to check if they're Stripe-enabled
/// and, if they are, we remove their MSP discount and set their Subscription to `send_invoice`. This is because /// and, if they are, we remove their MSP discount and set their Subscription to `send_invoice`. This is because
/// the provider's payment method will be removed from their Stripe customer causing ensuing charges to fail. Lastly, /// the provider's payment method will be removed from their Stripe customer, causing ensuing charges to fail. Lastly,
/// we email the organization owners letting them know they need to add a new payment method. /// we email the organization owners letting them know they need to add a new payment method.
/// </summary> /// </summary>
private async Task ResetOrganizationBillingAsync( private async Task ResetOrganizationBillingAsync(
@ -142,15 +133,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }] Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }]
}; };
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var setNonUSBusinessUseToReverseCharge = _featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{ {
_automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
} }
else else if (customer.HasRecognizedTaxLocation())
{ {
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{ {
Enabled = true Enabled = customer.Address.Country == "US" ||
customer.TaxIds.Any()
}; };
} }
@ -187,7 +181,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
await _mailService.SendProviderUpdatePaymentMethod( await _mailService.SendProviderUpdatePaymentMethod(
organization.Id, organization.Id,
organization.Name, organization.Name,
provider.Name, provider.Name!,
organizationOwnerEmails); organizationOwnerEmails);
} }
} }

View File

@ -5,6 +5,7 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Provider;
using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Models.Business.Tokenables;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
@ -53,6 +54,7 @@ public class ProviderService : IProviderService
private readonly IApplicationCacheService _applicationCacheService; private readonly IApplicationCacheService _applicationCacheService;
private readonly IProviderBillingService _providerBillingService; private readonly IProviderBillingService _providerBillingService;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly IProviderClientOrganizationSignUpCommand _providerClientOrganizationSignUpCommand;
public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository,
@ -61,7 +63,8 @@ public class ProviderService : IProviderService
IOrganizationRepository organizationRepository, GlobalSettings globalSettings, IOrganizationRepository organizationRepository, GlobalSettings globalSettings,
ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService, ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService,
IDataProtectorTokenFactory<ProviderDeleteTokenable> providerDeleteTokenDataFactory, IDataProtectorTokenFactory<ProviderDeleteTokenable> providerDeleteTokenDataFactory,
IApplicationCacheService applicationCacheService, IProviderBillingService providerBillingService, IPricingClient pricingClient) IApplicationCacheService applicationCacheService, IProviderBillingService providerBillingService, IPricingClient pricingClient,
IProviderClientOrganizationSignUpCommand providerClientOrganizationSignUpCommand)
{ {
_providerRepository = providerRepository; _providerRepository = providerRepository;
_providerUserRepository = providerUserRepository; _providerUserRepository = providerUserRepository;
@ -81,6 +84,7 @@ public class ProviderService : IProviderService
_applicationCacheService = applicationCacheService; _applicationCacheService = applicationCacheService;
_providerBillingService = providerBillingService; _providerBillingService = providerBillingService;
_pricingClient = pricingClient; _pricingClient = pricingClient;
_providerClientOrganizationSignUpCommand = providerClientOrganizationSignUpCommand;
} }
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null)
@ -560,12 +564,12 @@ public class ProviderService : IProviderService
ThrowOnInvalidPlanType(provider.Type, organizationSignup.Plan); ThrowOnInvalidPlanType(provider.Type, organizationSignup.Plan);
var (organization, _, defaultCollection) = await _organizationService.SignupClientAsync(organizationSignup); var signUpResponse = await _providerClientOrganizationSignUpCommand.SignUpClientOrganizationAsync(organizationSignup);
var providerOrganization = new ProviderOrganization var providerOrganization = new ProviderOrganization
{ {
ProviderId = providerId, ProviderId = providerId,
OrganizationId = organization.Id, OrganizationId = signUpResponse.Organization.Id,
Key = organizationSignup.OwnerKey, Key = organizationSignup.OwnerKey,
}; };
@ -574,12 +578,12 @@ public class ProviderService : IProviderService
// Give the owner Can Manage access over the default collection // Give the owner Can Manage access over the default collection
// The orgUser is not available when the org is created so we have to do it here as part of the invite // The orgUser is not available when the org is created so we have to do it here as part of the invite
var defaultOwnerAccess = defaultCollection != null var defaultOwnerAccess = signUpResponse.DefaultCollection != null
? ?
[ [
new CollectionAccessSelection new CollectionAccessSelection
{ {
Id = defaultCollection.Id, Id = signUpResponse.DefaultCollection.Id,
HidePasswords = false, HidePasswords = false,
ReadOnly = false, ReadOnly = false,
Manage = true Manage = true
@ -587,7 +591,7 @@ public class ProviderService : IProviderService
] ]
: Array.Empty<CollectionAccessSelection>(); : Array.Empty<CollectionAccessSelection>();
await _organizationService.InviteUsersAsync(organization.Id, user.Id, systemUser: null, await _organizationService.InviteUsersAsync(signUpResponse.Organization.Id, user.Id, systemUser: null,
new (OrganizationUserInvite, string)[] new (OrganizationUserInvite, string)[]
{ {
( (

View File

@ -67,6 +67,7 @@ public class BusinessUnitConverter(
organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb; organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb;
organization.UsePolicies = updatedPlan.HasPolicies; organization.UsePolicies = updatedPlan.HasPolicies;
organization.UseSso = updatedPlan.HasSso; organization.UseSso = updatedPlan.HasSso;
organization.UseOrganizationDomains = updatedPlan.HasOrganizationDomains;
organization.UseGroups = updatedPlan.HasGroups; organization.UseGroups = updatedPlan.HasGroups;
organization.UseEvents = updatedPlan.HasEvents; organization.UseEvents = updatedPlan.HasEvents;
organization.UseDirectory = updatedPlan.HasDirectory; organization.UseDirectory = updatedPlan.HasDirectory;

View File

@ -18,7 +18,6 @@ using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services; using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
@ -27,7 +26,6 @@ using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Braintree; using Braintree;
using CsvHelper; using CsvHelper;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Stripe; using Stripe;
@ -52,8 +50,7 @@ public class ProviderBillingService(
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService,
ITaxService taxService, ITaxService taxService)
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
: IProviderBillingService : IProviderBillingService
{ {
public async Task AddExistingOrganization( public async Task AddExistingOrganization(
@ -99,6 +96,7 @@ public class ProviderBillingService(
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.UsePolicies = plan.HasPolicies; organization.UsePolicies = plan.HasPolicies;
organization.UseSso = plan.HasSso; organization.UseSso = plan.HasSso;
organization.UseOrganizationDomains = plan.HasOrganizationDomains;
organization.UseGroups = plan.HasGroups; organization.UseGroups = plan.HasGroups;
organization.UseEvents = plan.HasEvents; organization.UseEvents = plan.HasEvents;
organization.UseDirectory = plan.HasDirectory; organization.UseDirectory = plan.HasDirectory;
@ -127,7 +125,7 @@ public class ProviderBillingService(
/* /*
* We have to scale the provider's seats before the ProviderOrganization * We have to scale the provider's seats before the ProviderOrganization
* row is inserted so the added organization's seats don't get double counted. * row is inserted so the added organization's seats don't get double-counted.
*/ */
await ScaleSeats(provider, organization.PlanType, organization.Seats!.Value); await ScaleSeats(provider, organization.PlanType, organization.Seats!.Value);
@ -235,7 +233,7 @@ public class ProviderBillingService(
var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions
{ {
Expand = ["tax_ids"] Expand = ["tax", "tax_ids"]
}); });
var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); var providerTaxId = providerCustomer.TaxIds.FirstOrDefault();
@ -283,6 +281,13 @@ public class ProviderBillingService(
] ]
}; };
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && providerCustomer.Address is not { Country: "US" })
{
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
organization.GatewayCustomerId = customer.Id; organization.GatewayCustomerId = customer.Id;
@ -519,6 +524,13 @@ public class ProviderBillingService(
} }
}; };
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && taxInfo.BillingAddressCountry != "US")
{
options.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber)) if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber))
{ {
var taxIdType = taxService.GetStripeTaxCode( var taxIdType = taxService.GetStripeTaxCode(
@ -530,6 +542,7 @@ public class ProviderBillingService(
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInfo.BillingAddressCountry, taxInfo.BillingAddressCountry,
taxInfo.TaxIdNumber); taxInfo.TaxIdNumber);
throw new BadRequestException("billingTaxIdTypeInferenceError"); throw new BadRequestException("billingTaxIdTypeInferenceError");
} }
@ -694,6 +707,13 @@ public class ProviderBillingService(
customer.Metadata.ContainsKey(BraintreeCustomerIdKey) || customer.Metadata.ContainsKey(BraintreeCustomerIdKey) ||
setupIntent.IsUnverifiedBankAccount()); setupIntent.IsUnverifiedBankAccount());
int? trialPeriodDays = provider.Type switch
{
ProviderType.Msp when usePaymentMethod => 14,
ProviderType.BusinessUnit when usePaymentMethod => 4,
_ => null
};
var subscriptionCreateOptions = new SubscriptionCreateOptions var subscriptionCreateOptions = new SubscriptionCreateOptions
{ {
CollectionMethod = usePaymentMethod ? CollectionMethod = usePaymentMethod ?
@ -707,17 +727,24 @@ public class ProviderBillingService(
}, },
OffSession = true, OffSession = true,
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
TrialPeriodDays = usePaymentMethod ? 14 : null TrialPeriodDays = trialPeriodDays
}; };
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var setNonUSBusinessUseToReverseCharge =
{ featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
} if (setNonUSBusinessUseToReverseCharge)
else
{ {
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
} }
else if (customer.HasRecognizedTaxLocation())
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customer.Address.Country == "US" ||
customer.TaxIds.Any()
};
}
try try
{ {

View File

@ -1,4 +1,5 @@
using Bit.Commercial.Core.AdminConsole.Providers; using Bit.Commercial.Core.AdminConsole.Providers;
using Bit.Core;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
@ -8,7 +9,6 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories; using Bit.Core.Repositories;
@ -224,31 +224,115 @@ public class RemoveOrganizationFromProviderCommandTests
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>(); var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
options.Description == string.Empty &&
options.Email == organization.BillingEmail &&
options.Expand[0] == "tax" &&
options.Expand[1] == "tax_ids")).Returns(new Customer
{
Id = "customer_id",
Address = new Address
{
Country = "US"
}
});
stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription
{ {
Id = "subscription_id" Id = "subscription_id"
}); });
sutProvider.GetDependency<IAutomaticTaxStrategy>() await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options => await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == organization.GatewayCustomerId && options.Customer == organization.GatewayCustomerId &&
options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice &&
options.DaysUntilDue == 30 && options.DaysUntilDue == 30 &&
options.Metadata["organizationId"] == organization.Id.ToString() && options.AutomaticTax.Enabled == true &&
options.OffSession == true && options.Metadata["organizationId"] == organization.Id.ToString() &&
options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && options.OffSession == true &&
options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
options.Items.First().Quantity == organization.Seats) options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId &&
, Arg.Any<Customer>())) options.Items.First().Quantity == organization.Seats));
.Do(x =>
await sutProvider.GetDependency<IProviderBillingService>().Received(1)
.ScaleSeats(provider, organization.PlanType, -organization.Seats ?? 0);
await organizationRepository.Received(1).ReplaceAsync(Arg.Is<Organization>(
org =>
org.BillingEmail == "a@example.com" &&
org.GatewaySubscriptionId == "subscription_id" &&
org.Status == OrganizationStatusType.Created));
await sutProvider.GetDependency<IProviderOrganizationRepository>().Received(1)
.DeleteAsync(providerOrganization);
await sutProvider.GetDependency<IEventService>().Received(1)
.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed);
await sutProvider.GetDependency<IMailService>().Received(1)
.SendProviderUpdatePaymentMethod(
organization.Id,
organization.Name,
provider.Name,
Arg.Is<IEnumerable<string>>(emails => emails.FirstOrDefault() == "a@example.com"));
}
[Theory, BitAutoData]
public async Task RemoveOrganizationFromProvider_OrganizationStripeEnabled_ConsolidatedBilling_ReverseCharge_MakesCorrectInvocations(
Provider provider,
ProviderOrganization providerOrganization,
Organization organization,
SutProvider<RemoveOrganizationFromProviderCommand> sutProvider)
{
provider.Status = ProviderStatusType.Billable;
providerOrganization.ProviderId = provider.Id;
organization.Status = OrganizationStatusType.Managed;
organization.PlanType = PlanType.TeamsMonthly;
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan);
sutProvider.GetDependency<IHasConfirmedOwnersExceptQuery>().HasConfirmedOwnersExceptAsync(
providerOrganization.OrganizationId,
[],
includeProvider: false)
.Returns(true);
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns([
"a@example.com",
"b@example.com"
]);
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
options.Description == string.Empty &&
options.Email == organization.BillingEmail &&
options.Expand[0] == "tax" &&
options.Expand[1] == "tax_ids")).Returns(new Customer
{ {
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions Id = "customer_id",
Address = new Address
{ {
Enabled = true Country = "US"
}; }
}); });
stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription
{
Id = "subscription_id"
});
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options => await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>

View File

@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Provider;
using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Models.Business.Tokenables;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
@ -717,8 +718,8 @@ public class ProviderServiceTests
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider); sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>(); var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup) sutProvider.GetDependency<IProviderClientOrganizationSignUpCommand>().SignUpClientOrganizationAsync(organizationSignup)
.Returns((organization, null as OrganizationUser, new Collection())); .Returns(new ProviderClientOrganizationSignUpResponse(organization, new Collection()));
var providerOrganization = var providerOrganization =
await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user); await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user);
@ -755,8 +756,8 @@ public class ProviderServiceTests
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>(); var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup) sutProvider.GetDependency<IProviderClientOrganizationSignUpCommand>().SignUpClientOrganizationAsync(organizationSignup)
.Returns((organization, null as OrganizationUser, new Collection())); .Returns(new ProviderClientOrganizationSignUpResponse(organization, new Collection()));
await Assert.ThrowsAsync<BadRequestException>(() => await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user)); sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user));
@ -782,8 +783,8 @@ public class ProviderServiceTests
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>(); var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup) sutProvider.GetDependency<IProviderClientOrganizationSignUpCommand>().SignUpClientOrganizationAsync(organizationSignup)
.Returns((organization, null as OrganizationUser, new Collection())); .Returns(new ProviderClientOrganizationSignUpResponse(organization, new Collection()));
var providerOrganization = await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user); var providerOrganization = await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user);
@ -821,8 +822,8 @@ public class ProviderServiceTests
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider); sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>(); var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup) sutProvider.GetDependency<IProviderClientOrganizationSignUpCommand>().SignUpClientOrganizationAsync(organizationSignup)
.Returns((organization, null as OrganizationUser, defaultCollection)); .Returns(new ProviderClientOrganizationSignUpResponse(organization, defaultCollection));
var providerOrganization = var providerOrganization =
await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user); await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user);

View File

@ -262,7 +262,7 @@ public class ProviderBillingServiceTests
}; };
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>( sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>(
options => options.Expand.FirstOrDefault() == "tax_ids")) options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids")))
.Returns(providerCustomer); .Returns(providerCustomer);
sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri
@ -312,6 +312,91 @@ public class ProviderBillingServiceTests
org => org.GatewayCustomerId == "customer_id")); org => org.GatewayCustomerId == "customer_id"));
} }
[Theory, BitAutoData]
public async Task CreateCustomer_ForClientOrg_ReverseCharge_Succeeds(
Provider provider,
Organization organization,
SutProvider<ProviderBillingService> sutProvider)
{
organization.GatewayCustomerId = null;
organization.Name = "Name";
organization.BusinessName = "BusinessName";
var providerCustomer = new Customer
{
Address = new Address
{
Country = "CA",
PostalCode = "12345",
Line1 = "123 Main St.",
Line2 = "Unit 4",
City = "Fake Town",
State = "Fake State"
},
TaxIds = new StripeList<TaxId>
{
Data =
[
new TaxId { Type = "TYPE", Value = "VALUE" }
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>(
options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids")))
.Returns(providerCustomer);
sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri
.Returns(new Bit.Core.Settings.GlobalSettings.BaseServiceUriSettings(new Bit.Core.Settings.GlobalSettings())
{
CloudRegion = "US"
});
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
sutProvider.GetDependency<IStripeAdapter>().CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
options =>
options.Address.Country == providerCustomer.Address.Country &&
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
options.Address.Line1 == providerCustomer.Address.Line1 &&
options.Address.Line2 == providerCustomer.Address.Line2 &&
options.Address.City == providerCustomer.Address.City &&
options.Address.State == providerCustomer.Address.State &&
options.Name == organization.DisplayName() &&
options.Description == $"{provider.Name} Client Organization" &&
options.Email == provider.BillingEmail &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
options.Metadata["region"] == "US" &&
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value &&
options.TaxExempt == StripeConstants.TaxExempt.Reverse))
.Returns(new Customer { Id = "customer_id" });
await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization);
await sutProvider.GetDependency<IStripeAdapter>().Received(1).CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
options =>
options.Address.Country == providerCustomer.Address.Country &&
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
options.Address.Line1 == providerCustomer.Address.Line1 &&
options.Address.Line2 == providerCustomer.Address.Line2 &&
options.Address.City == providerCustomer.Address.City &&
options.Address.State == providerCustomer.Address.State &&
options.Name == organization.DisplayName() &&
options.Description == $"{provider.Name} Client Organization" &&
options.Email == provider.BillingEmail &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
options.Metadata["region"] == "US" &&
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value));
await sutProvider.GetDependency<IOrganizationRepository>().Received(1).ReplaceAsync(Arg.Is<Organization>(
org => org.GatewayCustomerId == "customer_id"));
}
#endregion #endregion
#region GenerateClientInvoiceReport #region GenerateClientInvoiceReport
@ -1182,6 +1267,62 @@ public class ProviderBillingServiceTests
Assert.Equivalent(expected, actual); Assert.Equivalent(expected, actual);
} }
[Theory, BitAutoData]
public async Task SetupCustomer_WithCard_ReverseCharge_Success(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var expected = new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
};
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token");
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Email == provider.BillingEmail &&
o.PaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber &&
o.TaxExempt == StripeConstants.TaxExempt.Reverse))
.Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
Assert.Equivalent(expected, actual);
}
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid( public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid(
SutProvider<ProviderBillingService> sutProvider, SutProvider<ProviderBillingService> sutProvider,
@ -1307,7 +1448,7 @@ public class ProviderBillingServiceTests
.Returns(new Customer .Returns(new Customer
{ {
Id = "customer_id", Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Address = new Address { Country = "US" }
}); });
var providerPlans = new List<ProviderPlan> var providerPlans = new List<ProviderPlan>
@ -1359,7 +1500,7 @@ public class ProviderBillingServiceTests
var customer = new Customer var customer = new Customer
{ {
Id = "customer_id", Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } Address = new Address { Country = "US" }
}; };
sutProvider.GetDependency<ISubscriberService>() sutProvider.GetDependency<ISubscriberService>()
.GetCustomerOrThrow( .GetCustomerOrThrow(
@ -1399,19 +1540,6 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>( sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub => sub =>
sub.AutomaticTax.Enabled == true && sub.AutomaticTax.Enabled == true &&
@ -1443,11 +1571,11 @@ public class ProviderBillingServiceTests
var customer = new Customer var customer = new Customer
{ {
Id = "customer_id", Id = "customer_id",
Address = new Address { Country = "US" },
InvoiceSettings = new CustomerInvoiceSettings InvoiceSettings = new CustomerInvoiceSettings
{ {
DefaultPaymentMethodId = "pm_123" DefaultPaymentMethodId = "pm_123"
}, }
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
sutProvider.GetDependency<ISubscriberService>() sutProvider.GetDependency<ISubscriberService>()
@ -1488,19 +1616,6 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IFeatureService>() sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
@ -1536,9 +1651,9 @@ public class ProviderBillingServiceTests
var customer = new Customer var customer = new Customer
{ {
Id = "customer_id", Id = "customer_id",
Address = new Address { Country = "US" },
InvoiceSettings = new CustomerInvoiceSettings(), InvoiceSettings = new CustomerInvoiceSettings(),
Metadata = new Dictionary<string, string>(), Metadata = new Dictionary<string, string>()
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
sutProvider.GetDependency<ISubscriberService>() sutProvider.GetDependency<ISubscriberService>()
@ -1579,19 +1694,6 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IFeatureService>() sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
@ -1646,12 +1748,15 @@ public class ProviderBillingServiceTests
var customer = new Customer var customer = new Customer
{ {
Id = "customer_id", Id = "customer_id",
Address = new Address
{
Country = "US"
},
InvoiceSettings = new CustomerInvoiceSettings(), InvoiceSettings = new CustomerInvoiceSettings(),
Metadata = new Dictionary<string, string> Metadata = new Dictionary<string, string>
{ {
["btCustomerId"] = "braintree_customer_id" ["btCustomerId"] = "braintree_customer_id"
}, }
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}; };
sutProvider.GetDependency<ISubscriberService>() sutProvider.GetDependency<ISubscriberService>()
@ -1692,22 +1797,92 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>() sutProvider.GetDependency<IFeatureService>()
.When(x => x.SetCreateOptions( .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id") sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
, Arg.Is<Customer>(p => p == customer))) sub =>
.Do(x => sub.AutomaticTax.Enabled == true &&
sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically &&
sub.Customer == "customer_id" &&
sub.DaysUntilDue == null &&
sub.Items.Count == 2 &&
sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams &&
sub.Items.ElementAt(0).Quantity == 100 &&
sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise &&
sub.Items.ElementAt(1).Quantity == 100 &&
sub.Metadata["providerId"] == provider.Id.ToString() &&
sub.OffSession == true &&
sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
sub.TrialPeriodDays == 14)).Returns(expected);
var actual = await sutProvider.Sut.SetupSubscription(provider);
Assert.Equivalent(expected, actual);
}
[Theory, BitAutoData]
public async Task SetupSubscription_ReverseCharge_Succeeds(
SutProvider<ProviderBillingService> sutProvider,
Provider provider)
{
provider.Type = ProviderType.Msp;
provider.GatewaySubscriptionId = null;
var customer = new Customer
{
Id = "customer_id",
Address = new Address { Country = "CA" },
InvoiceSettings = new CustomerInvoiceSettings
{ {
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions DefaultPaymentMethodId = "pm_123"
{ }
Enabled = true };
};
}); sutProvider.GetDependency<ISubscriberService>()
.GetCustomerOrThrow(
provider,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer);
var providerPlans = new List<ProviderPlan>
{
new()
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.TeamsMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
AllocatedSeats = 0
},
new()
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.EnterpriseMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
AllocatedSeats = 0
}
};
foreach (var plan in providerPlans)
{
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(plan.PlanType)
.Returns(StaticStore.GetPlan(plan.PlanType));
}
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id)
.Returns(providerPlans);
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IFeatureService>() sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>( sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub => sub =>
sub.AutomaticTax.Enabled == true && sub.AutomaticTax.Enabled == true &&

View File

@ -11,6 +11,7 @@ MAILCATCHER_PORT=1080
# Alternative databases # Alternative databases
POSTGRES_PASSWORD=SET_A_PASSWORD_HERE_123 POSTGRES_PASSWORD=SET_A_PASSWORD_HERE_123
MYSQL_ROOT_PASSWORD=SET_A_PASSWORD_HERE_123 MYSQL_ROOT_PASSWORD=SET_A_PASSWORD_HERE_123
MARIADB_ROOT_PASSWORD=SET_A_PASSWORD_HERE_123
# IdP configuration # IdP configuration
# Complete using the values from the Manage SSO page in the web vault # Complete using the values from the Manage SSO page in the web vault

View File

@ -70,6 +70,20 @@ services:
profiles: profiles:
- mysql - mysql
mariadb:
image: mariadb:10
ports:
- 4306:3306
environment:
MARIADB_USER: maria
MARIADB_PASSWORD: ${MARIADB_ROOT_PASSWORD}
MARIADB_DATABASE: vault_dev
MARIADB_RANDOM_ROOT_PASSWORD: "true"
volumes:
- mariadb_dev_data:/var/lib/mysql
profiles:
- mariadb
idp: idp:
image: kenchan0130/simplesamlphp:1.19.8 image: kenchan0130/simplesamlphp:1.19.8
container_name: idp container_name: idp

View File

@ -5,6 +5,7 @@ param(
[switch]$all, [switch]$all,
[switch]$postgres, [switch]$postgres,
[switch]$mysql, [switch]$mysql,
[switch]$mariadb,
[switch]$mssql, [switch]$mssql,
[switch]$sqlite, [switch]$sqlite,
[switch]$selfhost, [switch]$selfhost,
@ -15,11 +16,15 @@ param(
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
$currentDir = Get-Location $currentDir = Get-Location
if (!$all -and !$postgres -and !$mysql -and !$sqlite) { function Get-IsEFDatabase {
return $postgres -or $mysql -or $mariadb -or $sqlite;
}
if (!$all -and !$(Get-IsEFDatabase)) {
$mssql = $true; $mssql = $true;
} }
if ($all -or $postgres -or $mysql -or $sqlite) { if ($all -or $(Get-IsEFDatabase)) {
dotnet ef *> $null dotnet ef *> $null
if ($LASTEXITCODE -ne 0) { if ($LASTEXITCODE -ne 0) {
Write-Host "Entity Framework Core tools were not found in the dotnet global tools. Attempting to install" Write-Host "Entity Framework Core tools were not found in the dotnet global tools. Attempting to install"
@ -60,9 +65,12 @@ if ($all -or $mssql) {
} }
Foreach ($item in @( Foreach ($item in @(
@($mysql, "MySQL", "MySqlMigrations", "mySql", 2),
@($postgres, "PostgreSQL", "PostgresMigrations", "postgreSql", 0), @($postgres, "PostgreSQL", "PostgresMigrations", "postgreSql", 0),
@($sqlite, "SQLite", "SqliteMigrations", "sqlite", 1) @($sqlite, "SQLite", "SqliteMigrations", "sqlite", 1),
@($mysql, "MySQL", "MySqlMigrations", "mySql", 2),
# MariaDB shares the MySQL connection string in the server config so they are mutually exclusive in that context.
# However they can still be run independently for integration tests.
@($mariadb, "MariaDB", "MySqlMigrations", "mySql", 3)
)) { )) {
if (!$item[0] -and !$all) { if (!$item[0] -and !$all) {
continue continue

View File

@ -40,8 +40,6 @@ export function authenticate(
payload["deviceName"] = "chrome"; payload["deviceName"] = "chrome";
payload["username"] = username; payload["username"] = username;
payload["password"] = password; payload["password"] = password;
params.headers["Auth-Email"] = encoding.b64encode(username);
} else { } else {
payload["scope"] = "api.organization"; payload["scope"] = "api.organization";
payload["grant_type"] = "client_credentials"; payload["grant_type"] = "client_credentials";

View File

@ -69,7 +69,7 @@
document.getElementById('@(nameof(Model.UseGroups))').checked = plan.hasGroups; document.getElementById('@(nameof(Model.UseGroups))').checked = plan.hasGroups;
document.getElementById('@(nameof(Model.UsePolicies))').checked = plan.hasPolicies; document.getElementById('@(nameof(Model.UsePolicies))').checked = plan.hasPolicies;
document.getElementById('@(nameof(Model.UseSso))').checked = plan.hasSso; document.getElementById('@(nameof(Model.UseSso))').checked = plan.hasSso;
document.getElementById('@(nameof(Model.UseOrganizationDomains))').checked = hasOrganizationDomains; document.getElementById('@(nameof(Model.UseOrganizationDomains))').checked = plan.hasOrganizationDomains;
document.getElementById('@(nameof(Model.UseScim))').checked = plan.hasScim; document.getElementById('@(nameof(Model.UseScim))').checked = plan.hasScim;
document.getElementById('@(nameof(Model.UseDirectory))').checked = plan.hasDirectory; document.getElementById('@(nameof(Model.UseDirectory))').checked = plan.hasDirectory;
document.getElementById('@(nameof(Model.UseEvents))').checked = plan.hasEvents; document.getElementById('@(nameof(Model.UseEvents))').checked = plan.hasEvents;

View File

@ -2,7 +2,7 @@
using Bit.Core; using Bit.Core;
using Bit.Core.Jobs; using Bit.Core.Jobs;
using Bit.Core.Tools.Repositories; using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.Services; using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Quartz; using Quartz;
namespace Bit.Admin.Tools.Jobs; namespace Bit.Admin.Tools.Jobs;
@ -32,10 +32,10 @@ public class DeleteSendsJob : BaseJob
} }
using (var scope = _serviceProvider.CreateScope()) using (var scope = _serviceProvider.CreateScope())
{ {
var sendService = scope.ServiceProvider.GetRequiredService<ISendService>(); var nonAnonymousSendCommand = scope.ServiceProvider.GetRequiredService<INonAnonymousSendCommand>();
foreach (var send in sends) foreach (var send in sends)
{ {
await sendService.DeleteSendAsync(send); await nonAnonymousSendCommand.DeleteSendAsync(send);
} }
} }
} }

View File

@ -2,13 +2,11 @@
using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.AdminConsole.Models.Response.Organizations; using Bit.Api.AdminConsole.Models.Response.Organizations;
using Bit.Api.Models.Response; using Bit.Api.Models.Response;
using Bit.Core;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
@ -137,7 +135,6 @@ public class OrganizationDomainController : Controller
[AllowAnonymous] [AllowAnonymous]
[HttpPost("domain/sso/verified")] [HttpPost("domain/sso/verified")]
[RequireFeature(FeatureFlagKeys.VerifiedSsoDomainEndpoint)]
public async Task<VerifiedOrganizationDomainSsoDetailsResponseModel> GetVerifiedOrgDomainSsoDetailsAsync( public async Task<VerifiedOrganizationDomainSsoDetailsResponseModel> GetVerifiedOrgDomainSsoDetailsAsync(
[FromBody] OrganizationDomainSsoDetailsRequestModel model) [FromBody] OrganizationDomainSsoDetailsRequestModel model)
{ {

View File

@ -292,15 +292,17 @@ public class OrganizationBillingController(
sale.Organization.PlanType = plan.Type; sale.Organization.PlanType = plan.Type;
sale.Organization.Plan = plan.Name; sale.Organization.Plan = plan.Name;
sale.SubscriptionSetup.SkipTrial = true; sale.SubscriptionSetup.SkipTrial = true;
await organizationBillingService.Finalize(sale);
if (organizationSignup.PaymentMethodType == null || string.IsNullOrEmpty(organizationSignup.PaymentToken))
{
return Error.BadRequest("A payment method is required to restart the subscription.");
}
var org = await organizationRepository.GetByIdAsync(organizationId); var org = await organizationRepository.GetByIdAsync(organizationId);
Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine."); Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine.");
if (organizationSignup.PaymentMethodType != null) var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken);
{ var taxInformation = TaxInformation.From(organizationSignup.TaxInfo);
var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken); await organizationBillingService.UpdatePaymentMethod(org, paymentSource, taxInformation);
var taxInformation = TaxInformation.From(organizationSignup.TaxInfo); await organizationBillingService.Finalize(sale);
await organizationBillingService.UpdatePaymentMethod(org, paymentSource, taxInformation);
}
return TypedResults.Ok(); return TypedResults.Ok();
} }

View File

@ -207,7 +207,7 @@ public class OrganizationSponsorshipsController : Controller
[HttpDelete("{sponsoringOrganizationId}")] [HttpDelete("{sponsoringOrganizationId}")]
[HttpPost("{sponsoringOrganizationId}/delete")] [HttpPost("{sponsoringOrganizationId}/delete")]
[SelfHosted(NotSelfHostedOnly = true)] [SelfHosted(NotSelfHostedOnly = true)]
public async Task RevokeSponsorship(Guid sponsoringOrganizationId, [FromQuery] bool isAdminInitiated = false) public async Task RevokeSponsorship(Guid sponsoringOrganizationId)
{ {
var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrganizationId, _currentContext.UserId ?? default); var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrganizationId, _currentContext.UserId ?? default);
@ -217,11 +217,25 @@ public class OrganizationSponsorshipsController : Controller
} }
var existingOrgSponsorship = await _organizationSponsorshipRepository var existingOrgSponsorship = await _organizationSponsorshipRepository
.GetBySponsoringOrganizationUserIdAsync(orgUser.Id, isAdminInitiated); .GetBySponsoringOrganizationUserIdAsync(orgUser.Id);
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
} }
[Authorize("Application")]
[HttpDelete("{sponsoringOrgId}/{sponsoredFriendlyName}/revoke")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task AdminInitiatedRevokeSponsorshipAsync(Guid sponsoringOrgId, string sponsoredFriendlyName)
{
var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrgId);
var existingOrgSponsorship = sponsorships.FirstOrDefault(s => s.FriendlyName != null && s.FriendlyName.Equals(sponsoredFriendlyName, StringComparison.OrdinalIgnoreCase));
if (existingOrgSponsorship == null)
{
throw new BadRequestException("The specified sponsored organization could not be found under the given sponsoring organization.");
}
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
}
[Authorize("Application")] [Authorize("Application")]
[HttpDelete("sponsored/{sponsoredOrgId}")] [HttpDelete("sponsored/{sponsoredOrgId}")]
[HttpPost("sponsored/{sponsoredOrgId}/remove")] [HttpPost("sponsored/{sponsoredOrgId}/remove")]

View File

@ -109,28 +109,6 @@ public class OrganizationsController(
return license; return license;
} }
[HttpPost("{id:guid}/payment")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task PostPayment(Guid id, [FromBody] PaymentRequestModel model)
{
if (!await currentContext.EditPaymentMethods(id))
{
throw new NotFoundException();
}
await organizationService.ReplacePaymentMethodAsync(id, model.PaymentToken,
model.PaymentMethodType.Value, new TaxInfo
{
BillingAddressLine1 = model.Line1,
BillingAddressLine2 = model.Line2,
BillingAddressState = model.State,
BillingAddressCity = model.City,
BillingAddressPostalCode = model.PostalCode,
BillingAddressCountry = model.Country,
TaxIdNumber = model.TaxId,
});
}
[HttpPost("{id:guid}/upgrade")] [HttpPost("{id:guid}/upgrade")]
[SelfHosted(NotSelfHostedOnly = true)] [SelfHosted(NotSelfHostedOnly = true)]
public async Task<PaymentResponseModel> PostUpgrade(Guid id, [FromBody] OrganizationUpgradeRequestModel model) public async Task<PaymentResponseModel> PostUpgrade(Guid id, [FromBody] OrganizationUpgradeRequestModel model)

View File

@ -1,6 +1,10 @@
using Bit.Api.Models.Request.Organizations; using Bit.Api.AdminConsole.Authorization.Requirements;
using Bit.Api.Models.Request.Organizations;
using Bit.Api.Models.Response;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Api.Response.OrganizationSponsorships;
using Bit.Core.Models.Data.Organizations.OrganizationSponsorships;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
@ -22,6 +26,7 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
private readonly IAuthorizationService _authorizationService;
public SelfHostedOrganizationSponsorshipsController( public SelfHostedOrganizationSponsorshipsController(
ICreateSponsorshipCommand offerSponsorshipCommand, ICreateSponsorshipCommand offerSponsorshipCommand,
@ -30,7 +35,8 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationSponsorshipRepository organizationSponsorshipRepository,
IOrganizationUserRepository organizationUserRepository, IOrganizationUserRepository organizationUserRepository,
ICurrentContext currentContext, ICurrentContext currentContext,
IFeatureService featureService IFeatureService featureService,
IAuthorizationService authorizationService
) )
{ {
_offerSponsorshipCommand = offerSponsorshipCommand; _offerSponsorshipCommand = offerSponsorshipCommand;
@ -40,6 +46,7 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
_organizationUserRepository = organizationUserRepository; _organizationUserRepository = organizationUserRepository;
_currentContext = currentContext; _currentContext = currentContext;
_featureService = featureService; _featureService = featureService;
_authorizationService = authorizationService;
} }
[HttpPost("{sponsoringOrgId}/families-for-enterprise")] [HttpPost("{sponsoringOrgId}/families-for-enterprise")]
@ -84,4 +91,41 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
} }
[HttpDelete("{sponsoringOrgId}/{sponsoredFriendlyName}/revoke")]
public async Task AdminInitiatedRevokeSponsorshipAsync(Guid sponsoringOrgId, string sponsoredFriendlyName)
{
var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrgId);
var existingOrgSponsorship = sponsorships.FirstOrDefault(s => s.FriendlyName != null && s.FriendlyName.Equals(sponsoredFriendlyName, StringComparison.OrdinalIgnoreCase));
if (existingOrgSponsorship == null)
{
throw new BadRequestException("The specified sponsored organization could not be found under the given sponsoring organization.");
}
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
}
[Authorize("Application")]
[HttpGet("{orgId}/sponsored")]
public async Task<ListResponseModel<OrganizationSponsorshipInvitesResponseModel>> GetSponsoredOrganizations(Guid orgId)
{
var sponsoringOrg = await _organizationRepository.GetByIdAsync(orgId);
if (sponsoringOrg == null)
{
throw new NotFoundException();
}
var authorizationResult = await _authorizationService.AuthorizeAsync(User, orgId, new ManageUsersRequirement());
if (!authorizationResult.Succeeded)
{
throw new UnauthorizedAccessException();
}
var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(orgId);
return new ListResponseModel<OrganizationSponsorshipInvitesResponseModel>(
sponsorships
.Where(s => s.IsAdminInitiated)
.Select(s => new OrganizationSponsorshipInvitesResponseModel(new OrganizationSponsorshipData(s)))
);
}
} }

View File

@ -12,17 +12,17 @@ namespace Bit.Api.KeyManagement.Validators;
/// </summary> /// </summary>
public class SendRotationValidator : IRotationValidator<IEnumerable<SendWithIdRequestModel>, IReadOnlyList<Send>> public class SendRotationValidator : IRotationValidator<IEnumerable<SendWithIdRequestModel>, IReadOnlyList<Send>>
{ {
private readonly ISendService _sendService; private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly ISendRepository _sendRepository; private readonly ISendRepository _sendRepository;
/// <summary> /// <summary>
/// Instantiates a new <see cref="SendRotationValidator"/> /// Instantiates a new <see cref="SendRotationValidator"/>
/// </summary> /// </summary>
/// <param name="sendService">Enables conversion of <see cref="SendWithIdRequestModel"/> to <see cref="Send"/></param> /// <param name="sendAuthorizationService">Enables conversion of <see cref="SendWithIdRequestModel"/> to <see cref="Send"/></param>
/// <param name="sendRepository">Retrieves all user <see cref="Send"/>s</param> /// <param name="sendRepository">Retrieves all user <see cref="Send"/>s</param>
public SendRotationValidator(ISendService sendService, ISendRepository sendRepository) public SendRotationValidator(ISendAuthorizationService sendAuthorizationService, ISendRepository sendRepository)
{ {
_sendService = sendService; _sendAuthorizationService = sendAuthorizationService;
_sendRepository = sendRepository; _sendRepository = sendRepository;
} }
@ -44,7 +44,7 @@ public class SendRotationValidator : IRotationValidator<IEnumerable<SendWithIdRe
throw new BadRequestException("All existing sends must be included in the rotation."); throw new BadRequestException("All existing sends must be included in the rotation.");
} }
result.Add(send.ToSend(existing, _sendService)); result.Add(send.ToSend(existing, _sendAuthorizationService));
} }
return result; return result;

View File

@ -35,6 +35,7 @@ using Bit.Core.Services;
using Bit.Core.Tools.ImportFeatures; using Bit.Core.Tools.ImportFeatures;
using Bit.Core.Tools.ReportFeatures; using Bit.Core.Tools.ReportFeatures;
using Bit.Core.Auth.Models.Api.Request; using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Tools.SendFeatures;
#if !OSS #if !OSS
using Bit.Commercial.Core.SecretsManager; using Bit.Commercial.Core.SecretsManager;
@ -186,6 +187,7 @@ public class Startup
services.AddPhishingDomainServices(globalSettings); services.AddPhishingDomainServices(globalSettings);
services.AddBillingQueries(); services.AddBillingQueries();
services.AddSendServices();
// Authorization Handlers // Authorization Handlers
services.AddAuthorizationHandlers(); services.AddAuthorizationHandlers();

View File

@ -12,6 +12,8 @@ using Bit.Core.Settings;
using Bit.Core.Tools.Enums; using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data; using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories; using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services; using Bit.Core.Tools.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
@ -25,8 +27,10 @@ public class SendsController : Controller
{ {
private readonly ISendRepository _sendRepository; private readonly ISendRepository _sendRepository;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly ISendService _sendService; private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly ISendFileStorageService _sendFileStorageService; private readonly ISendFileStorageService _sendFileStorageService;
private readonly IAnonymousSendCommand _anonymousSendCommand;
private readonly INonAnonymousSendCommand _nonAnonymousSendCommand;
private readonly ILogger<SendsController> _logger; private readonly ILogger<SendsController> _logger;
private readonly GlobalSettings _globalSettings; private readonly GlobalSettings _globalSettings;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
@ -34,7 +38,9 @@ public class SendsController : Controller
public SendsController( public SendsController(
ISendRepository sendRepository, ISendRepository sendRepository,
IUserService userService, IUserService userService,
ISendService sendService, ISendAuthorizationService sendAuthorizationService,
IAnonymousSendCommand anonymousSendCommand,
INonAnonymousSendCommand nonAnonymousSendCommand,
ISendFileStorageService sendFileStorageService, ISendFileStorageService sendFileStorageService,
ILogger<SendsController> logger, ILogger<SendsController> logger,
GlobalSettings globalSettings, GlobalSettings globalSettings,
@ -42,13 +48,16 @@ public class SendsController : Controller
{ {
_sendRepository = sendRepository; _sendRepository = sendRepository;
_userService = userService; _userService = userService;
_sendService = sendService; _sendAuthorizationService = sendAuthorizationService;
_anonymousSendCommand = anonymousSendCommand;
_nonAnonymousSendCommand = nonAnonymousSendCommand;
_sendFileStorageService = sendFileStorageService; _sendFileStorageService = sendFileStorageService;
_logger = logger; _logger = logger;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_currentContext = currentContext; _currentContext = currentContext;
} }
#region Anonymous endpoints
[AllowAnonymous] [AllowAnonymous]
[HttpPost("access/{id}")] [HttpPost("access/{id}")]
public async Task<IActionResult> Access(string id, [FromBody] SendAccessRequestModel model) public async Task<IActionResult> Access(string id, [FromBody] SendAccessRequestModel model)
@ -61,18 +70,19 @@ public class SendsController : Controller
//} //}
var guid = new Guid(CoreHelpers.Base64UrlDecode(id)); var guid = new Guid(CoreHelpers.Base64UrlDecode(id));
var (send, passwordRequired, passwordInvalid) = var send = await _sendRepository.GetByIdAsync(guid);
await _sendService.AccessAsync(guid, model.Password); SendAccessResult sendAuthResult =
if (passwordRequired) await _sendAuthorizationService.AccessAsync(send, model.Password);
if (sendAuthResult.Equals(SendAccessResult.PasswordRequired))
{ {
return new UnauthorizedResult(); return new UnauthorizedResult();
} }
if (passwordInvalid) if (sendAuthResult.Equals(SendAccessResult.PasswordInvalid))
{ {
await Task.Delay(2000); await Task.Delay(2000);
throw new BadRequestException("Invalid password."); throw new BadRequestException("Invalid password.");
} }
if (send == null) if (sendAuthResult.Equals(SendAccessResult.Denied))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
@ -106,19 +116,19 @@ public class SendsController : Controller
throw new BadRequestException("Could not locate send"); throw new BadRequestException("Could not locate send");
} }
var (url, passwordRequired, passwordInvalid) = await _sendService.GetSendFileDownloadUrlAsync(send, fileId, var (url, result) = await _anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId,
model.Password); model.Password);
if (passwordRequired) if (result.Equals(SendAccessResult.PasswordRequired))
{ {
return new UnauthorizedResult(); return new UnauthorizedResult();
} }
if (passwordInvalid) if (result.Equals(SendAccessResult.PasswordInvalid))
{ {
await Task.Delay(2000); await Task.Delay(2000);
throw new BadRequestException("Invalid password."); throw new BadRequestException("Invalid password.");
} }
if (send == null) if (result.Equals(SendAccessResult.Denied))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
@ -130,6 +140,45 @@ public class SendsController : Controller
}); });
} }
[AllowAnonymous]
[HttpPost("file/validate/azure")]
public async Task<ObjectResult> AzureValidateFile()
{
return await ApiHelpers.HandleAzureEvents(Request, new Dictionary<string, Func<EventGridEvent, Task>>
{
{
"Microsoft.Storage.BlobCreated", async (eventGridEvent) =>
{
try
{
var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1];
var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName);
var send = await _sendRepository.GetByIdAsync(new Guid(sendId));
if (send == null)
{
if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService)
{
await azureSendFileStorageService.DeleteBlobAsync(blobName);
}
return;
}
await _nonAnonymousSendCommand.ConfirmFileSize(send);
}
catch (Exception e)
{
_logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}");
return;
}
}
}
});
}
#endregion
#region Non-anonymous endpoints
[HttpGet("{id}")] [HttpGet("{id}")]
public async Task<SendResponseModel> Get(string id) public async Task<SendResponseModel> Get(string id)
{ {
@ -157,8 +206,8 @@ public class SendsController : Controller
{ {
model.ValidateCreation(); model.ValidateCreation();
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var send = model.ToSend(userId, _sendService); var send = model.ToSend(userId, _sendAuthorizationService);
await _sendService.SaveSendAsync(send); await _nonAnonymousSendCommand.SaveSendAsync(send);
return new SendResponseModel(send, _globalSettings); return new SendResponseModel(send, _globalSettings);
} }
@ -175,15 +224,15 @@ public class SendsController : Controller
throw new BadRequestException("Invalid content. File size hint is required."); throw new BadRequestException("Invalid content. File size hint is required.");
} }
if (model.FileLength.Value > SendService.MAX_FILE_SIZE) if (model.FileLength.Value > Constants.FileSize501mb)
{ {
throw new BadRequestException($"Max file size is {SendService.MAX_FILE_SIZE_READABLE}."); throw new BadRequestException($"Max file size is {SendFileSettingHelper.MAX_FILE_SIZE_READABLE}.");
} }
model.ValidateCreation(); model.ValidateCreation();
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var (send, data) = model.ToSend(userId, model.File.FileName, _sendService); var (send, data) = model.ToSend(userId, model.File.FileName, _sendAuthorizationService);
var uploadUrl = await _sendService.SaveFileSendAsync(send, data, model.FileLength.Value); var uploadUrl = await _nonAnonymousSendCommand.SaveFileSendAsync(send, data, model.FileLength.Value);
return new SendFileUploadDataResponseModel return new SendFileUploadDataResponseModel
{ {
Url = uploadUrl, Url = uploadUrl,
@ -230,41 +279,7 @@ public class SendsController : Controller
var send = await _sendRepository.GetByIdAsync(new Guid(id)); var send = await _sendRepository.GetByIdAsync(new Guid(id));
await Request.GetFileAsync(async (stream) => await Request.GetFileAsync(async (stream) =>
{ {
await _sendService.UploadFileToExistingSendAsync(stream, send); await _nonAnonymousSendCommand.UploadFileToExistingSendAsync(stream, send);
});
}
[AllowAnonymous]
[HttpPost("file/validate/azure")]
public async Task<ObjectResult> AzureValidateFile()
{
return await ApiHelpers.HandleAzureEvents(Request, new Dictionary<string, Func<EventGridEvent, Task>>
{
{
"Microsoft.Storage.BlobCreated", async (eventGridEvent) =>
{
try
{
var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1];
var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName);
var send = await _sendRepository.GetByIdAsync(new Guid(sendId));
if (send == null)
{
if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService)
{
await azureSendFileStorageService.DeleteBlobAsync(blobName);
}
return;
}
await _sendService.ValidateSendFile(send);
}
catch (Exception e)
{
_logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}");
return;
}
}
}
}); });
} }
@ -279,7 +294,7 @@ public class SendsController : Controller
throw new NotFoundException(); throw new NotFoundException();
} }
await _sendService.SaveSendAsync(model.ToSend(send, _sendService)); await _nonAnonymousSendCommand.SaveSendAsync(model.ToSend(send, _sendAuthorizationService));
return new SendResponseModel(send, _globalSettings); return new SendResponseModel(send, _globalSettings);
} }
@ -294,7 +309,7 @@ public class SendsController : Controller
} }
send.Password = null; send.Password = null;
await _sendService.SaveSendAsync(send); await _nonAnonymousSendCommand.SaveSendAsync(send);
return new SendResponseModel(send, _globalSettings); return new SendResponseModel(send, _globalSettings);
} }
@ -308,6 +323,8 @@ public class SendsController : Controller
throw new NotFoundException(); throw new NotFoundException();
} }
await _sendService.DeleteSendAsync(send); await _nonAnonymousSendCommand.DeleteSendAsync(send);
} }
#endregion
} }

View File

@ -36,31 +36,31 @@ public class SendRequestModel
public bool? Disabled { get; set; } public bool? Disabled { get; set; }
public bool? HideEmail { get; set; } public bool? HideEmail { get; set; }
public Send ToSend(Guid userId, ISendService sendService) public Send ToSend(Guid userId, ISendAuthorizationService sendAuthorizationService)
{ {
var send = new Send var send = new Send
{ {
Type = Type, Type = Type,
UserId = (Guid?)userId UserId = (Guid?)userId
}; };
ToSend(send, sendService); ToSend(send, sendAuthorizationService);
return send; return send;
} }
public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendService sendService) public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendAuthorizationService sendAuthorizationService)
{ {
var send = ToSendBase(new Send var send = ToSendBase(new Send
{ {
Type = Type, Type = Type,
UserId = (Guid?)userId UserId = (Guid?)userId
}, sendService); }, sendAuthorizationService);
var data = new SendFileData(Name, Notes, fileName); var data = new SendFileData(Name, Notes, fileName);
return (send, data); return (send, data);
} }
public Send ToSend(Send existingSend, ISendService sendService) public Send ToSend(Send existingSend, ISendAuthorizationService sendAuthorizationService)
{ {
existingSend = ToSendBase(existingSend, sendService); existingSend = ToSendBase(existingSend, sendAuthorizationService);
switch (existingSend.Type) switch (existingSend.Type)
{ {
case SendType.File: case SendType.File:
@ -125,7 +125,7 @@ public class SendRequestModel
} }
} }
private Send ToSendBase(Send existingSend, ISendService sendService) private Send ToSendBase(Send existingSend, ISendAuthorizationService authorizationService)
{ {
existingSend.Key = Key; existingSend.Key = Key;
existingSend.ExpirationDate = ExpirationDate; existingSend.ExpirationDate = ExpirationDate;
@ -133,7 +133,7 @@ public class SendRequestModel
existingSend.MaxAccessCount = MaxAccessCount; existingSend.MaxAccessCount = MaxAccessCount;
if (!string.IsNullOrWhiteSpace(Password)) if (!string.IsNullOrWhiteSpace(Password))
{ {
existingSend.Password = sendService.HashPassword(Password); existingSend.Password = authorizationService.HashPassword(Password);
} }
existingSend.Disabled = Disabled.GetValueOrDefault(); existingSend.Disabled = Disabled.GetValueOrDefault();
existingSend.HideEmail = HideEmail.GetValueOrDefault(); existingSend.HideEmail = HideEmail.GetValueOrDefault();

View File

@ -1,11 +1,11 @@
using Bit.Core; using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
@ -25,8 +25,7 @@ public class UpcomingInvoiceHandler(
IStripeEventService stripeEventService, IStripeEventService stripeEventService,
IStripeEventUtilityService stripeEventUtilityService, IStripeEventUtilityService stripeEventUtilityService,
IUserRepository userRepository, IUserRepository userRepository,
IValidateSponsorshipCommand validateSponsorshipCommand, IValidateSponsorshipCommand validateSponsorshipCommand)
IAutomaticTaxFactory automaticTaxFactory)
: IUpcomingInvoiceHandler : IUpcomingInvoiceHandler
{ {
public async Task HandleAsync(Event parsedEvent) public async Task HandleAsync(Event parsedEvent)
@ -46,6 +45,8 @@ public class UpcomingInvoiceHandler(
var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata);
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (organizationId.HasValue) if (organizationId.HasValue)
{ {
var organization = await organizationRepository.GetByIdAsync(organizationId.Value); var organization = await organizationRepository.GetByIdAsync(organizationId.Value);
@ -55,7 +56,7 @@ public class UpcomingInvoiceHandler(
return; return;
} }
await TryEnableAutomaticTaxAsync(subscription); await AlignOrganizationTaxConcernsAsync(organization, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge);
var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); var plan = await pricingClient.GetPlanOrThrow(organization.PlanType);
@ -100,7 +101,25 @@ public class UpcomingInvoiceHandler(
return; return;
} }
await TryEnableAutomaticTaxAsync(subscription); if (!subscription.AutomaticTax.Enabled && subscription.Customer.HasRecognizedTaxLocation())
{
try
{
await stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}",
user.Id,
parsedEvent.Id);
}
}
if (user.Premium) if (user.Premium)
{ {
@ -116,7 +135,7 @@ public class UpcomingInvoiceHandler(
return; return;
} }
await TryEnableAutomaticTaxAsync(subscription); await AlignProviderTaxConcernsAsync(provider, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge);
await SendUpcomingInvoiceEmailsAsync(new List<string> { provider.BillingEmail }, invoice); await SendUpcomingInvoiceEmailsAsync(new List<string> { provider.BillingEmail }, invoice);
} }
@ -139,50 +158,123 @@ public class UpcomingInvoiceHandler(
} }
} }
private async Task TryEnableAutomaticTaxAsync(Subscription subscription) private async Task AlignOrganizationTaxConcernsAsync(
Organization organization,
Subscription subscription,
string eventId,
bool setNonUSBusinessUseToReverseCharge)
{ {
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var nonUSBusinessUse =
{ organization.PlanType.GetProductTier() != ProductTierType.Families &&
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id)); subscription.Customer.Address.Country != "US";
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (updateOptions == null) bool setAutomaticTaxToEnabled;
if (setNonUSBusinessUseToReverseCharge)
{
if (nonUSBusinessUse && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
{ {
return; try
{
await stripeFacade.UpdateCustomer(subscription.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}",
organization.Id,
eventId);
}
} }
await stripeFacade.UpdateSubscription(subscription.Id, updateOptions); setAutomaticTaxToEnabled = true;
return;
} }
else
if (subscription.AutomaticTax.Enabled ||
!subscription.Customer.HasBillingLocation() ||
await IsNonTaxableNonUSBusinessUseSubscription(subscription))
{ {
return; setAutomaticTaxToEnabled =
subscription.Customer.HasRecognizedTaxLocation() &&
(subscription.Customer.Address.Country == "US" ||
(nonUSBusinessUse && subscription.Customer.TaxIds.Any()));
} }
await stripeFacade.UpdateSubscription(subscription.Id, if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled)
new SubscriptionUpdateOptions {
try
{ {
DefaultTaxRates = [], await stripeFacade.UpdateSubscription(subscription.Id,
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } new SubscriptionUpdateOptions
}); {
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}",
organization.Id,
eventId);
}
}
}
return; private async Task AlignProviderTaxConcernsAsync(
Provider provider,
Subscription subscription,
string eventId,
bool setNonUSBusinessUseToReverseCharge)
{
bool setAutomaticTaxToEnabled;
async Task<bool> IsNonTaxableNonUSBusinessUseSubscription(Subscription localSubscription) if (setNonUSBusinessUseToReverseCharge)
{ {
var familyPriceIds = (await Task.WhenAll( if (subscription.Customer.Address.Country != "US" && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), {
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually))) try
.Select(plan => plan.PasswordManager.StripePlanId); {
await stripeFacade.UpdateCustomer(subscription.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}",
provider.Id,
eventId);
}
}
return localSubscription.Customer.Address.Country != "US" && setAutomaticTaxToEnabled = true;
localSubscription.Metadata.ContainsKey(StripeConstants.MetadataKeys.OrganizationId) && }
!localSubscription.Items.Select(item => item.Price.Id).Intersect(familyPriceIds).Any() && else
!localSubscription.Customer.TaxIds.Any(); {
setAutomaticTaxToEnabled =
subscription.Customer.HasRecognizedTaxLocation() &&
(subscription.Customer.Address.Country == "US" ||
subscription.Customer.TaxIds.Any());
}
if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled)
{
try
{
await stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}",
provider.Id,
eventId);
}
} }
} }
} }

View File

@ -325,5 +325,6 @@ public class Organization : ITableObject<Guid>, IStorableSubscriber, IRevisable,
SmServiceAccounts = license.SmServiceAccounts; SmServiceAccounts = license.SmServiceAccounts;
UseRiskInsights = license.UseRiskInsights; UseRiskInsights = license.UseRiskInsights;
UseOrganizationDomains = license.UseOrganizationDomains; UseOrganizationDomains = license.UseOrganizationDomains;
UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies;
} }
} }

View File

@ -150,6 +150,7 @@ public class SelfHostedOrganizationDetails : Organization
AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems, AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems,
Status = Status, Status = Status,
UseRiskInsights = UseRiskInsights, UseRiskInsights = UseRiskInsights,
UseAdminSponsoredFamilies = UseAdminSponsoredFamilies,
}; };
} }
} }

View File

@ -24,9 +24,7 @@ public class GetOrganizationUsersClaimedStatusQuery : IGetOrganizationUsersClaim
// Users can only be claimed by an Organization that is enabled and can have organization domains // Users can only be claimed by an Organization that is enabled and can have organization domains
var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId); var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId);
// TODO: Replace "UseSso" with a new organization ability like "UseOrganizationDomains" (PM-11622). if (organizationAbility is { Enabled: true, UseOrganizationDomains: true })
// Verified domains were tied to SSO, so we currently check the "UseSso" organization ability.
if (organizationAbility is { Enabled: true, UseSso: true })
{ {
// Get all organization users with claimed domains by the organization // Get all organization users with claimed domains by the organization
var organizationUsersWithClaimedDomain = await _organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organizationId); var organizationUsersWithClaimedDomain = await _organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organizationId);

View File

@ -0,0 +1,187 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Pricing;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Models.StaticStore;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
public record ProviderClientOrganizationSignUpResponse(
Organization Organization,
Collection DefaultCollection);
public interface IProviderClientOrganizationSignUpCommand
{
/// <summary>
/// Sign up a new client organization for a provider.
/// </summary>
/// <param name="signup">The signup information.</param>
/// <returns>A tuple containing the new organization and its default collection.</returns>
Task<ProviderClientOrganizationSignUpResponse> SignUpClientOrganizationAsync(OrganizationSignup signup);
}
public class ProviderClientOrganizationSignUpCommand : IProviderClientOrganizationSignUpCommand
{
public const string PlanNullErrorMessage = "Password Manager Plan was null.";
public const string PlanDisabledErrorMessage = "Password Manager Plan is disabled.";
public const string AdditionalSeatsNegativeErrorMessage = "You can't subtract Password Manager seats!";
private readonly ICurrentContext _currentContext;
private readonly IPricingClient _pricingClient;
private readonly IReferenceEventService _referenceEventService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository;
private readonly IApplicationCacheService _applicationCacheService;
private readonly ICollectionRepository _collectionRepository;
public ProviderClientOrganizationSignUpCommand(
ICurrentContext currentContext,
IPricingClient pricingClient,
IReferenceEventService referenceEventService,
IOrganizationRepository organizationRepository,
IOrganizationApiKeyRepository organizationApiKeyRepository,
IApplicationCacheService applicationCacheService,
ICollectionRepository collectionRepository)
{
_currentContext = currentContext;
_pricingClient = pricingClient;
_referenceEventService = referenceEventService;
_organizationRepository = organizationRepository;
_organizationApiKeyRepository = organizationApiKeyRepository;
_applicationCacheService = applicationCacheService;
_collectionRepository = collectionRepository;
}
public async Task<ProviderClientOrganizationSignUpResponse> SignUpClientOrganizationAsync(OrganizationSignup signup)
{
var plan = await _pricingClient.GetPlanOrThrow(signup.Plan);
ValidatePlan(plan, signup.AdditionalSeats);
var organization = new Organization
{
// Pre-generate the org id so that we can save it with the Stripe subscription.
Id = CoreHelpers.GenerateComb(),
Name = signup.Name,
BillingEmail = signup.BillingEmail,
PlanType = plan!.Type,
Seats = signup.AdditionalSeats,
MaxCollections = plan.PasswordManager.MaxCollections,
MaxStorageGb = 1,
UsePolicies = plan.HasPolicies,
UseSso = plan.HasSso,
UseOrganizationDomains = plan.HasOrganizationDomains,
UseGroups = plan.HasGroups,
UseEvents = plan.HasEvents,
UseDirectory = plan.HasDirectory,
UseTotp = plan.HasTotp,
Use2fa = plan.Has2fa,
UseApi = plan.HasApi,
UseResetPassword = plan.HasResetPassword,
SelfHost = plan.HasSelfHost,
UsersGetPremium = plan.UsersGetPremium,
UseCustomPermissions = plan.HasCustomPermissions,
UseScim = plan.HasScim,
Plan = plan.Name,
Gateway = GatewayType.Stripe,
ReferenceData = signup.Owner.ReferenceData,
Enabled = true,
LicenseKey = CoreHelpers.SecureRandomString(20),
PublicKey = signup.PublicKey,
PrivateKey = signup.PrivateKey,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
Status = OrganizationStatusType.Created,
UsePasswordManager = true,
// Secrets Manager not available for purchase with Consolidated Billing.
UseSecretsManager = false,
};
var returnValue = await SignUpAsync(organization, signup.CollectionName);
await _referenceEventService.RaiseEventAsync(
new ReferenceEvent(ReferenceEventType.Signup, organization, _currentContext)
{
PlanName = plan.Name,
PlanType = plan.Type,
Seats = returnValue.Organization.Seats,
SignupInitiationPath = signup.InitiationPath,
Storage = returnValue.Organization.MaxStorageGb,
});
return returnValue;
}
private static void ValidatePlan(Plan plan, int additionalSeats)
{
if (plan is null)
{
throw new BadRequestException(PlanNullErrorMessage);
}
if (plan.Disabled)
{
throw new BadRequestException(PlanDisabledErrorMessage);
}
if (additionalSeats < 0)
{
throw new BadRequestException(AdditionalSeatsNegativeErrorMessage);
}
}
/// <summary>
/// Private helper method to create a new organization.
/// </summary>
private async Task<ProviderClientOrganizationSignUpResponse> SignUpAsync(
Organization organization, string collectionName)
{
try
{
await _organizationRepository.CreateAsync(organization);
await _organizationApiKeyRepository.CreateAsync(new OrganizationApiKey
{
OrganizationId = organization.Id,
ApiKey = CoreHelpers.SecureRandomString(30),
Type = OrganizationApiKeyType.Default,
RevisionDate = DateTime.UtcNow,
});
await _applicationCacheService.UpsertOrganizationAbilityAsync(organization);
Collection defaultCollection = null;
if (!string.IsNullOrWhiteSpace(collectionName))
{
defaultCollection = new Collection
{
Name = collectionName,
OrganizationId = organization.Id,
CreationDate = organization.CreationDate,
RevisionDate = organization.CreationDate
};
await _collectionRepository.CreateAsync(defaultCollection, null, null);
}
return new ProviderClientOrganizationSignUpResponse(organization, defaultCollection);
}
catch
{
if (organization.Id != default)
{
await _organizationRepository.DeleteAsync(organization);
await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id);
}
throw;
}
}
}

View File

@ -11,8 +11,6 @@ namespace Bit.Core.Services;
public interface IOrganizationService public interface IOrganizationService
{ {
Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType,
TaxInfo taxInfo);
Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null); Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null);
Task ReinstateSubscriptionAsync(Guid organizationId); Task ReinstateSubscriptionAsync(Guid organizationId);
Task<string> AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb); Task<string> AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb);
@ -20,9 +18,6 @@ public interface IOrganizationService
Task AutoAddSeatsAsync(Organization organization, int seatsToAdd); Task AutoAddSeatsAsync(Organization organization, int seatsToAdd);
Task<string> AdjustSeatsAsync(Guid organizationId, int seatAdjustment); Task<string> AdjustSeatsAsync(Guid organizationId, int seatAdjustment);
Task VerifyBankAsync(Guid organizationId, int amount1, int amount2); Task VerifyBankAsync(Guid organizationId, int amount1, int amount2);
#nullable enable
Task<(Organization organization, OrganizationUser organizationUser, Collection defaultCollection)> SignupClientAsync(OrganizationSignup signup);
#nullable disable
/// <summary> /// <summary>
/// Create a new organization on a self-hosted instance /// Create a new organization on a self-hosted instance
/// </summary> /// </summary>

View File

@ -144,27 +144,6 @@ public class OrganizationService : IOrganizationService
_sendOrganizationInvitesCommand = sendOrganizationInvitesCommand; _sendOrganizationInvitesCommand = sendOrganizationInvitesCommand;
} }
public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken,
PaymentMethodType paymentMethodType, TaxInfo taxInfo)
{
var organization = await GetOrgById(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _paymentService.SaveTaxInfoAsync(organization, taxInfo);
var updated = await _paymentService.UpdatePaymentMethodAsync(
organization,
paymentMethodType,
paymentToken,
taxInfo);
if (updated)
{
await ReplaceAndUpdateCacheAsync(organization);
}
}
public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null) public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null)
{ {
var organization = await GetOrgById(organizationId); var organization = await GetOrgById(organizationId);
@ -431,66 +410,6 @@ public class OrganizationService : IOrganizationService
} }
} }
public async Task<(Organization organization, OrganizationUser organizationUser, Collection defaultCollection)> SignupClientAsync(OrganizationSignup signup)
{
var plan = await _pricingClient.GetPlanOrThrow(signup.Plan);
ValidatePlan(plan, signup.AdditionalSeats, "Password Manager");
var organization = new Organization
{
// Pre-generate the org id so that we can save it with the Stripe subscription.
Id = CoreHelpers.GenerateComb(),
Name = signup.Name,
BillingEmail = signup.BillingEmail,
PlanType = plan!.Type,
Seats = signup.AdditionalSeats,
MaxCollections = plan.PasswordManager.MaxCollections,
MaxStorageGb = 1,
UsePolicies = plan.HasPolicies,
UseSso = plan.HasSso,
UseOrganizationDomains = plan.HasOrganizationDomains,
UseGroups = plan.HasGroups,
UseEvents = plan.HasEvents,
UseDirectory = plan.HasDirectory,
UseTotp = plan.HasTotp,
Use2fa = plan.Has2fa,
UseApi = plan.HasApi,
UseResetPassword = plan.HasResetPassword,
SelfHost = plan.HasSelfHost,
UsersGetPremium = plan.UsersGetPremium,
UseCustomPermissions = plan.HasCustomPermissions,
UseScim = plan.HasScim,
Plan = plan.Name,
Gateway = GatewayType.Stripe,
ReferenceData = signup.Owner.ReferenceData,
Enabled = true,
LicenseKey = CoreHelpers.SecureRandomString(20),
PublicKey = signup.PublicKey,
PrivateKey = signup.PrivateKey,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
Status = OrganizationStatusType.Created,
UsePasswordManager = true,
// Secrets Manager not available for purchase with Consolidated Billing.
UseSecretsManager = false,
};
var returnValue = await SignUpAsync(organization, default, signup.OwnerKey, signup.CollectionName, false);
await _referenceEventService.RaiseEventAsync(
new ReferenceEvent(ReferenceEventType.Signup, organization, _currentContext)
{
PlanName = plan.Name,
PlanType = plan.Type,
Seats = returnValue.Item1.Seats,
SignupInitiationPath = signup.InitiationPath,
Storage = returnValue.Item1.MaxStorageGb,
});
return returnValue;
}
private async Task ValidateSignUpPoliciesAsync(Guid ownerId) private async Task ValidateSignUpPoliciesAsync(Guid ownerId)
{ {
var anySingleOrgPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg); var anySingleOrgPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg);
@ -572,6 +491,7 @@ public class OrganizationService : IOrganizationService
SmServiceAccounts = license.SmServiceAccounts, SmServiceAccounts = license.SmServiceAccounts,
UseRiskInsights = license.UseRiskInsights, UseRiskInsights = license.UseRiskInsights,
UseOrganizationDomains = license.UseOrganizationDomains, UseOrganizationDomains = license.UseOrganizationDomains,
UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies,
}; };
var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false); var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false);

View File

@ -2,9 +2,24 @@
public enum EmergencyAccessStatusType : byte public enum EmergencyAccessStatusType : byte
{ {
/// <summary>
/// The user has been invited to be an emergency contact.
/// </summary>
Invited = 0, Invited = 0,
/// <summary>
/// The invited user, "grantee", has accepted the request to be an emergency contact.
/// </summary>
Accepted = 1, Accepted = 1,
/// <summary>
/// The inviting user, "grantor", has approved the grantee's acceptance.
/// </summary>
Confirmed = 2, Confirmed = 2,
/// <summary>
/// The grantee has initiated the recovery process.
/// </summary>
RecoveryInitiated = 3, RecoveryInitiated = 3,
/// <summary>
/// The grantee has excercised their emergency access.
/// </summary>
RecoveryApproved = 4, RecoveryApproved = 4,
} }

View File

@ -3,6 +3,7 @@ using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums; using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Models.Data;
@ -20,6 +21,15 @@ public interface IEmergencyAccessService
Task InitiateAsync(Guid id, User initiatingUser); Task InitiateAsync(Guid id, User initiatingUser);
Task ApproveAsync(Guid id, User approvingUser); Task ApproveAsync(Guid id, User approvingUser);
Task RejectAsync(Guid id, User rejectingUser); Task RejectAsync(Guid id, User rejectingUser);
/// <summary>
/// This request is made by the Grantee user to fetch the policies <see cref="Policy"/> for the Grantor User.
/// The Grantor User has to be the owner of the organization. <see cref="OrganizationUserType"/>
/// If the Grantor user has OrganizationUserType.Owner then the policies for the _Grantor_ user
/// are returned.
/// </summary>
/// <param name="id">EmergencyAccess.Id being acted on</param>
/// <param name="requestingUser">User making the request, this is the Grantee</param>
/// <returns>null if the GrantorUser is not an organization owner; A list of policies otherwise.</returns>
Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser); Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser);
Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser); Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser);
Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key); Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key);

View File

@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities; using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums; using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities; using Bit.Core.Entities;
@ -16,7 +15,6 @@ using Bit.Core.Tokens;
using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Models.Data;
using Bit.Core.Vault.Repositories; using Bit.Core.Vault.Repositories;
using Bit.Core.Vault.Services; using Bit.Core.Vault.Services;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Auth.Services; namespace Bit.Core.Auth.Services;
@ -31,8 +29,6 @@ public class EmergencyAccessService : IEmergencyAccessService
private readonly IMailService _mailService; private readonly IMailService _mailService;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly GlobalSettings _globalSettings; private readonly GlobalSettings _globalSettings;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IOrganizationService _organizationService;
private readonly IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> _dataProtectorTokenizer; private readonly IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> _dataProtectorTokenizer;
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
@ -45,9 +41,7 @@ public class EmergencyAccessService : IEmergencyAccessService
ICipherService cipherService, ICipherService cipherService,
IMailService mailService, IMailService mailService,
IUserService userService, IUserService userService,
IPasswordHasher<User> passwordHasher,
GlobalSettings globalSettings, GlobalSettings globalSettings,
IOrganizationService organizationService,
IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> dataProtectorTokenizer, IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> dataProtectorTokenizer,
IRemoveOrganizationUserCommand removeOrganizationUserCommand) IRemoveOrganizationUserCommand removeOrganizationUserCommand)
{ {
@ -59,9 +53,7 @@ public class EmergencyAccessService : IEmergencyAccessService
_cipherService = cipherService; _cipherService = cipherService;
_mailService = mailService; _mailService = mailService;
_userService = userService; _userService = userService;
_passwordHasher = passwordHasher;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_organizationService = organizationService;
_dataProtectorTokenizer = dataProtectorTokenizer; _dataProtectorTokenizer = dataProtectorTokenizer;
_removeOrganizationUserCommand = removeOrganizationUserCommand; _removeOrganizationUserCommand = removeOrganizationUserCommand;
} }
@ -126,7 +118,12 @@ public class EmergencyAccessService : IEmergencyAccessService
throw new BadRequestException("Emergency Access not valid."); throw new BadRequestException("Emergency Access not valid.");
} }
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email)) if (!_dataProtectorTokenizer.TryUnprotect(token, out var data))
{
throw new BadRequestException("Invalid token.");
}
if (!data.IsValid(emergencyAccessId, user.Email))
{ {
throw new BadRequestException("Invalid token."); throw new BadRequestException("Invalid token.");
} }
@ -140,6 +137,8 @@ public class EmergencyAccessService : IEmergencyAccessService
throw new BadRequestException("Invitation already accepted."); throw new BadRequestException("Invitation already accepted.");
} }
// TODO PM-21687
// Might not be reachable since the Tokenable.IsValid() does an email comparison
if (string.IsNullOrWhiteSpace(emergencyAccess.Email) || if (string.IsNullOrWhiteSpace(emergencyAccess.Email) ||
!emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) !emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{ {
@ -163,6 +162,8 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId) public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId)
{ {
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
// TODO PM-19438/PM-21687
// Not sure why the GrantorId and the GranteeId are supposed to be the same?
if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId)) if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId))
{ {
throw new BadRequestException("Emergency Access not valid."); throw new BadRequestException("Emergency Access not valid.");
@ -171,9 +172,9 @@ public class EmergencyAccessService : IEmergencyAccessService
await _emergencyAccessRepository.DeleteAsync(emergencyAccess); await _emergencyAccessRepository.DeleteAsync(emergencyAccess);
} }
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId) public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAccessId, string key, Guid confirmingUserId)
{ {
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId); var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted || if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted ||
emergencyAccess.GrantorId != confirmingUserId) emergencyAccess.GrantorId != confirmingUserId)
{ {
@ -224,7 +225,6 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task InitiateAsync(Guid id, User initiatingUser) public async Task InitiateAsync(Guid id, User initiatingUser)
{ {
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id || if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.Confirmed) emergencyAccess.Status != EmergencyAccessStatusType.Confirmed)
{ {
@ -285,6 +285,9 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser) public async Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser)
{ {
// TODO PM-21687
// Should we look up policies here or just verify the EmergencyAccess is correct
// and handle policy logic else where? Should this be a query/Command?
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
@ -295,7 +298,9 @@ public class EmergencyAccessService : IEmergencyAccessService
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id);
var isOrganizationOwner = grantorOrganizations.Any<OrganizationUser>(organization => organization.Type == OrganizationUserType.Owner); var isOrganizationOwner = grantorOrganizations
.Any(organization => organization.Type == OrganizationUserType.Owner);
var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null; var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null;
return policies; return policies;
@ -311,7 +316,8 @@ public class EmergencyAccessService : IEmergencyAccessService
} }
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
// TODO PM-21687
// Redundant check of the EmergencyAccessType -> checked in IsValidRequest() ln 308
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{ {
throw new BadRequestException("You cannot takeover an account that is using Key Connector."); throw new BadRequestException("You cannot takeover an account that is using Key Connector.");
@ -336,7 +342,9 @@ public class EmergencyAccessService : IEmergencyAccessService
grantor.LastPasswordChangeDate = grantor.RevisionDate; grantor.LastPasswordChangeDate = grantor.RevisionDate;
grantor.Key = key; grantor.Key = key;
// Disable TwoFactor providers since they will otherwise block logins // Disable TwoFactor providers since they will otherwise block logins
grantor.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>()); grantor.SetTwoFactorProviders([]);
// Disable New Device Verification since it will otherwise block logins
grantor.VerifyDevices = false;
await _userRepository.ReplaceAsync(grantor); await _userRepository.ReplaceAsync(grantor);
// Remove grantor from all organizations unless Owner // Remove grantor from all organizations unless Owner
@ -421,12 +429,22 @@ public class EmergencyAccessService : IEmergencyAccessService
await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token); await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token);
} }
private string NameOrEmail(User user) private static string NameOrEmail(User user)
{ {
return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name; return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name;
} }
private bool IsValidRequest(EmergencyAccess availableAccess, User requestingUser, EmergencyAccessType requestedAccessType)
/*
* Checks if EmergencyAccess Object is null
* Checks the requesting user is the same as the granteeUser (So we are checking for proper grantee action)
* Status _must_ equal RecoveryApproved (This means the grantor has invited, the grantee has accepted, and the grantor has approved so the shared key exists but hasn't been exercised yet)
* request type must equal the type of access requested (View or Takeover)
*/
private static bool IsValidRequest(
EmergencyAccess availableAccess,
User requestingUser,
EmergencyAccessType requestedAccessType)
{ {
return availableAccess != null && return availableAccess != null &&
availableAccess.GranteeId == requestingUser.Id && availableAccess.GranteeId == requestingUser.Id &&

View File

@ -2,10 +2,6 @@
public static class StripeConstants public static class StripeConstants
{ {
public static class Prices
{
public const string StoragePlanPersonal = "personal-storage-gb-annually";
}
public static class AutomaticTaxStatus public static class AutomaticTaxStatus
{ {
public const string Failed = "failed"; public const string Failed = "failed";
@ -69,6 +65,11 @@ public static class StripeConstants
public const string USBankAccount = "us_bank_account"; public const string USBankAccount = "us_bank_account";
} }
public static class Prices
{
public const string StoragePlanPersonal = "personal-storage-gb-annually";
}
public static class ProrationBehavior public static class ProrationBehavior
{ {
public const string AlwaysInvoice = "always_invoice"; public const string AlwaysInvoice = "always_invoice";
@ -88,6 +89,13 @@ public static class StripeConstants
public const string Paused = "paused"; public const string Paused = "paused";
} }
public static class TaxExempt
{
public const string Exempt = "exempt";
public const string None = "none";
public const string Reverse = "reverse";
}
public static class ValidateTaxLocationTiming public static class ValidateTaxLocationTiming
{ {
public const string Deferred = "deferred"; public const string Deferred = "deferred";

View File

@ -15,12 +15,7 @@ public static class CustomerExtensions
} }
}; };
/// <summary> public static bool HasRecognizedTaxLocation(this Customer customer) =>
/// Determines if a Stripe customer supports automatic tax
/// </summary>
/// <param name="customer"></param>
/// <returns></returns>
public static bool HasTaxLocationVerified(this Customer customer) =>
customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation; customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation;
public static decimal GetBillingBalance(this Customer customer) public static decimal GetBillingBalance(this Customer customer)

View File

@ -22,7 +22,7 @@ public static class SubscriptionUpdateOptionsExtensions
} }
// We might only need to check the automatic tax status. // We might only need to check the automatic tax status.
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{ {
return false; return false;
} }

View File

@ -22,7 +22,7 @@ public static class UpcomingInvoiceOptionsExtensions
} }
// We might only need to check the automatic tax status. // We might only need to check the automatic tax status.
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{ {
return false; return false;
} }

View File

@ -309,6 +309,7 @@ public class OrganizationMigrator(
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
organization.UsePolicies = plan.HasPolicies; organization.UsePolicies = plan.HasPolicies;
organization.UseSso = plan.HasSso; organization.UseSso = plan.HasSso;
organization.UseOrganizationDomains = plan.HasOrganizationDomains;
organization.UseGroups = plan.HasGroups; organization.UseGroups = plan.HasGroups;
organization.UseEvents = plan.HasEvents; organization.UseEvents = plan.HasEvents;
organization.UseDirectory = plan.HasDirectory; organization.UseDirectory = plan.HasDirectory;

View File

@ -26,6 +26,7 @@ public record Enterprise2019Plan : Plan
Has2fa = true; Has2fa = true;
HasApi = true; HasApi = true;
HasSso = true; HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true; HasKeyConnector = true;
HasScim = true; HasScim = true;
HasResetPassword = true; HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record Enterprise2020Plan : Plan
Has2fa = true; Has2fa = true;
HasApi = true; HasApi = true;
HasSso = true; HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true; HasKeyConnector = true;
HasScim = true; HasScim = true;
HasResetPassword = true; HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record EnterprisePlan : Plan
Has2fa = true; Has2fa = true;
HasApi = true; HasApi = true;
HasSso = true; HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true; HasKeyConnector = true;
HasScim = true; HasScim = true;
HasResetPassword = true; HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record Enterprise2023Plan : Plan
Has2fa = true; Has2fa = true;
HasApi = true; HasApi = true;
HasSso = true; HasSso = true;
HasOrganizationDomains = true;
HasKeyConnector = true; HasKeyConnector = true;
HasScim = true; HasScim = true;
HasResetPassword = true; HasResetPassword = true;

View File

@ -26,6 +26,7 @@ public record PlanAdapter : Plan
Has2fa = HasFeature("2fa"); Has2fa = HasFeature("2fa");
HasApi = HasFeature("api"); HasApi = HasFeature("api");
HasSso = HasFeature("sso"); HasSso = HasFeature("sso");
HasOrganizationDomains = HasFeature("organizationDomains");
HasKeyConnector = HasFeature("keyConnector"); HasKeyConnector = HasFeature("keyConnector");
HasScim = HasFeature("scim"); HasScim = HasFeature("scim");
HasResetPassword = HasFeature("resetPassword"); HasResetPassword = HasFeature("resetPassword");

View File

@ -1,11 +1,11 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services; using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums; using Bit.Core.Enums;
@ -35,16 +35,15 @@ public class OrganizationBillingService(
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService,
ITaxService taxService, ITaxService taxService) : IOrganizationBillingService
IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService
{ {
public async Task Finalize(OrganizationSale sale) public async Task Finalize(OrganizationSale sale)
{ {
var (organization, customerSetup, subscriptionSetup) = sale; var (organization, customerSetup, subscriptionSetup) = sale;
var customer = string.IsNullOrEmpty(organization.GatewayCustomerId) && customerSetup != null var customer = string.IsNullOrEmpty(organization.GatewayCustomerId) && customerSetup != null
? await CreateCustomerAsync(organization, customerSetup) ? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType)
: await subscriberService.GetCustomerOrThrow(organization, new CustomerGetOptions { Expand = ["tax", "tax_ids"] }); : await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup);
var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup); var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup);
@ -121,7 +120,8 @@ public class OrganizationBillingService(
subscription.CurrentPeriodEnd); subscription.CurrentPeriodEnd);
} }
public async Task UpdatePaymentMethod( public async Task
UpdatePaymentMethod(
Organization organization, Organization organization,
TokenizedPaymentSource tokenizedPaymentSource, TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation) TaxInformation taxInformation)
@ -151,8 +151,11 @@ public class OrganizationBillingService(
private async Task<Customer> CreateCustomerAsync( private async Task<Customer> CreateCustomerAsync(
Organization organization, Organization organization,
CustomerSetup customerSetup) CustomerSetup customerSetup,
PlanType? updatedPlanType = null)
{ {
var planType = updatedPlanType ?? organization.PlanType;
var displayName = organization.DisplayName(); var displayName = organization.DisplayName();
var customerCreateOptions = new CustomerCreateOptions var customerCreateOptions = new CustomerCreateOptions
@ -212,13 +215,24 @@ public class OrganizationBillingService(
City = customerSetup.TaxInformation.City, City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode, PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State, State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country, Country = customerSetup.TaxInformation.Country
}; };
customerCreateOptions.Tax = new CustomerTaxOptions customerCreateOptions.Tax = new CustomerTaxOptions
{ {
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
}; };
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge &&
planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families &&
customerSetup.TaxInformation.Country != "US")
{
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId)) if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId))
{ {
var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country, var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country,
@ -399,21 +413,68 @@ public class OrganizationBillingService(
TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays
}; };
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{ {
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType); subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
} }
else else if (customer.HasRecognizedTaxLocation())
{ {
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions(); subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation(); {
Enabled =
subscriptionSetup.PlanType.GetProductTier() == ProductTierType.Families ||
customer.Address.Country == "US" ||
customer.TaxIds.Any()
};
} }
return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
} }
private async Task<Customer> GetCustomerWhileEnsuringCorrectTaxExemptionAsync(
Organization organization,
SubscriptionSetup subscriptionSetup)
{
var customer = await subscriberService.GetCustomerOrThrow(organization,
new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (!setNonUSBusinessUseToReverseCharge || subscriptionSetup.PlanType.GetProductTier() is
not (ProductTierType.Teams or
ProductTierType.TeamsStarter or
ProductTierType.Enterprise))
{
return customer;
}
List<string> expansions = ["tax", "tax_ids"];
customer = customer switch
{
{ Address.Country: not "US", TaxExempt: not StripeConstants.TaxExempt.Reverse } => await
stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions
{
Expand = expansions,
TaxExempt = StripeConstants.TaxExempt.Reverse
}),
{ Address.Country: "US", TaxExempt: StripeConstants.TaxExempt.Reverse } => await
stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions
{
Expand = expansions,
TaxExempt = StripeConstants.TaxExempt.None
}),
_ => customer
};
return customer;
}
private async Task<bool> IsEligibleForSelfHostAsync( private async Task<bool> IsEligibleForSelfHostAsync(
Organization organization) Organization organization)
{ {

View File

@ -3,8 +3,6 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
@ -12,7 +10,6 @@ using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Braintree; using Braintree;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Stripe; using Stripe;
using Customer = Stripe.Customer; using Customer = Stripe.Customer;
@ -24,20 +21,18 @@ using static Utilities;
public class PremiumUserBillingService( public class PremiumUserBillingService(
IBraintreeGateway braintreeGateway, IBraintreeGateway braintreeGateway,
IFeatureService featureService,
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
ILogger<PremiumUserBillingService> logger, ILogger<PremiumUserBillingService> logger,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ISubscriberService subscriberService, ISubscriberService subscriberService,
IUserRepository userRepository, IUserRepository userRepository) : IPremiumUserBillingService
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService
{ {
public async Task Credit(User user, decimal amount) public async Task Credit(User user, decimal amount)
{ {
var customer = await subscriberService.GetCustomer(user); var customer = await subscriberService.GetCustomer(user);
// Negative credit represents a balance and all Stripe denomination is in cents. // Negative credit represents a balance, and all Stripe denomination is in cents.
var credit = (long)(amount * -100); var credit = (long)(amount * -100);
if (customer == null) if (customer == null)
@ -184,7 +179,7 @@ public class PremiumUserBillingService(
City = customerSetup.TaxInformation.City, City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode, PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State, State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country, Country = customerSetup.TaxInformation.Country
}, },
Description = user.Name, Description = user.Name,
Email = user.Email, Email = user.Email,
@ -324,6 +319,10 @@ public class PremiumUserBillingService(
var subscriptionCreateOptions = new SubscriptionCreateOptions var subscriptionCreateOptions = new SubscriptionCreateOptions
{ {
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id, Customer = customer.Id,
Items = subscriptionItemOptionsList, Items = subscriptionItemOptionsList,
@ -337,18 +336,6 @@ public class PremiumUserBillingService(
OffSession = true OffSession = true
}; };
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported,
};
}
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
if (usingPayPal) if (usingPayPal)
@ -380,7 +367,7 @@ public class PremiumUserBillingService(
City = taxInformation.City, City = taxInformation.City,
PostalCode = taxInformation.PostalCode, PostalCode = taxInformation.PostalCode,
State = taxInformation.State, State = taxInformation.State,
Country = taxInformation.Country, Country = taxInformation.Country
}, },
Expand = ["tax"], Expand = ["tax"],
Tax = new CustomerTaxOptions Tax = new CustomerTaxOptions

View File

@ -1,7 +1,10 @@
using Bit.Core.Billing.Caches; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services; using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities; using Bit.Core.Entities;
@ -28,8 +31,7 @@ public class SubscriberService(
ILogger<SubscriberService> logger, ILogger<SubscriberService> logger,
ISetupIntentCache setupIntentCache, ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter, IStripeAdapter stripeAdapter,
ITaxService taxService, ITaxService taxService) : ISubscriberService
IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService
{ {
public async Task CancelSubscription( public async Task CancelSubscription(
ISubscriber subscriber, ISubscriber subscriber,
@ -128,7 +130,7 @@ public class SubscriberService(
[subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion [subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion
}, },
Email = subscriber.BillingEmailAddress(), Email = subscriber.BillingEmailAddress(),
PaymentMethodNonce = paymentMethodNonce, PaymentMethodNonce = paymentMethodNonce
}); });
if (customerResult.IsSuccess()) if (customerResult.IsSuccess())
@ -482,7 +484,7 @@ public class SubscriberService(
var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First(); var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First();
// Find the customer's existing setup intents that should be cancelled. // Find the customer's existing setup intents that should be canceled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer) var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si => .Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action"); si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
@ -519,7 +521,7 @@ public class SubscriberService(
await stripeAdapter.PaymentMethodAttachAsync(token, await stripeAdapter.PaymentMethodAttachAsync(token,
new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId });
// Find the customer's existing setup intents that should be cancelled. // Find the customer's existing setup intents that should be canceled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer) var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si => .Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action"); si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
@ -637,7 +639,8 @@ public class SubscriberService(
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInformation.Country, taxInformation.Country,
taxInformation.TaxId); taxInformation.TaxId);
throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError");
throw new BadRequestException("billingTaxIdTypeInferenceError");
} }
} }
@ -654,53 +657,84 @@ public class SubscriberService(
logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.",
taxInformation.TaxId, taxInformation.TaxId,
taxInformation.Country); taxInformation.Country);
throw new Exceptions.BadRequestException("billingInvalidTaxIdError");
throw new BadRequestException("billingInvalidTaxIdError");
default: default:
logger.LogError(e, logger.LogError(e,
"Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.",
taxInformation.TaxId, taxInformation.TaxId,
taxInformation.Country, taxInformation.Country,
customer.Id); customer.Id);
throw new Exceptions.BadRequestException("billingTaxIdCreationError");
throw new BadRequestException("billingTaxIdCreationError");
} }
} }
} }
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var subscription =
customer.Subscriptions.First(subscription => subscription.Id == subscriber.GatewaySubscriptionId);
var isBusinessUseSubscriber = subscriber switch
{ {
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) Organization organization => organization.PlanType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families,
Provider => true,
_ => false
};
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && isBusinessUseSubscriber)
{
switch (customer)
{ {
var subscriptionGetOptions = new SubscriptionGetOptions case
{ {
Expand = ["customer.tax", "customer.tax_ids"] Address.Country: not "US",
}; TaxExempt: not StripeConstants.TaxExempt.Reverse
var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); }:
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); await stripeAdapter.CustomerUpdateAsync(customer.Id,
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription); break;
if (automaticTaxOptions?.AutomaticTax?.Enabled != null) case
{ {
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions); Address.Country: "US",
} TaxExempt: StripeConstants.TaxExempt.Reverse
}:
await stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.None });
break;
} }
}
else if (!subscription.AutomaticTax.Enabled)
{
if (SubscriberIsEligibleForAutomaticTax(subscriber, customer))
{ {
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
new SubscriptionUpdateOptions new SubscriptionUpdateOptions
{ {
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
}); });
} }
}
else
{
var automaticTaxShouldBeEnabled = subscriber switch
{
User => true,
Organization organization => organization.PlanType.GetProductTier() == ProductTierType.Families ||
customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
Provider => customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
_ => false
};
return; if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled)
{
bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
=> !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && new SubscriptionUpdateOptions
(localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && {
localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
} }
} }

View File

@ -76,7 +76,7 @@ public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : I
private bool ShouldBeEnabled(Customer customer) private bool ShouldBeEnabled(Customer customer)
{ {
if (!customer.HasTaxLocationVerified()) if (!customer.HasRecognizedTaxLocation())
{ {
return false; return false;
} }

View File

@ -59,6 +59,6 @@ public class PersonalUseAutomaticTaxStrategy(IFeatureService featureService) : I
private static bool ShouldBeEnabled(Customer customer) private static bool ShouldBeEnabled(Customer customer)
{ {
return customer.HasTaxLocationVerified(); return customer.HasRecognizedTaxLocation();
} }
} }

View File

@ -115,6 +115,7 @@ public static class FeatureFlagKeys
public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence"; public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence";
public const string EmailVerification = "email-verification"; public const string EmailVerification = "email-verification";
public const string UnauthenticatedExtensionUIRefresh = "unauth-ui-refresh"; public const string UnauthenticatedExtensionUIRefresh = "unauth-ui-refresh";
public const string BrowserExtensionLoginApproval = "pm-14938-browser-extension-login-approvals";
public const string SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor"; public const string SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor";
public const string ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor"; public const string ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor";
public const string RecoveryCodeLogin = "pm-17128-recovery-code-login"; public const string RecoveryCodeLogin = "pm-17128-recovery-code-login";
@ -142,13 +143,13 @@ public static class FeatureFlagKeys
public const string UsePricingService = "use-pricing-service"; public const string UsePricingService = "use-pricing-service";
public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features";
public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method"; public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method";
public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements";
public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates";
public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion"; public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion";
public const string PM199566_UpdateMSPToChargeAutomatically = "pm-199566-update-msp-to-charge-automatically"; public const string PM199566_UpdateMSPToChargeAutomatically = "pm-199566-update-msp-to-charge-automatically";
public const string PM19956_RequireProviderPaymentMethodDuringSetup = "pm-19956-require-provider-payment-method-during-setup"; public const string PM19956_RequireProviderPaymentMethodDuringSetup = "pm-19956-require-provider-payment-method-during-setup";
public const string UseOrganizationWarningsService = "use-organization-warnings-service"; public const string UseOrganizationWarningsService = "use-organization-warnings-service";
public const string PM20322_AllowTrialLength0 = "pm-20322-allow-trial-length-0"; public const string PM20322_AllowTrialLength0 = "pm-20322-allow-trial-length-0";
public const string PM21092_SetNonUSBusinessUseToReverseCharge = "pm-21092-set-non-us-business-use-to-reverse-charge";
/* Data Insights and Reporting Team */ /* Data Insights and Reporting Team */
public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application"; public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application";

View File

@ -84,6 +84,7 @@ public class OrganizationLicense : ILicense
SmSeats = org.SmSeats; SmSeats = org.SmSeats;
SmServiceAccounts = org.SmServiceAccounts; SmServiceAccounts = org.SmServiceAccounts;
UseRiskInsights = org.UseRiskInsights; UseRiskInsights = org.UseRiskInsights;
UseOrganizationDomains = org.UseOrganizationDomains;
// Deprecated. Left for backwards compatibility with old license versions. // Deprecated. Left for backwards compatibility with old license versions.
LimitCollectionCreationDeletion = org.LimitCollectionCreation || org.LimitCollectionDeletion; LimitCollectionCreationDeletion = org.LimitCollectionCreation || org.LimitCollectionDeletion;
@ -195,10 +196,10 @@ public class OrganizationLicense : ILicense
/// <remarks>Intentionally set one version behind to allow self hosted users some time to update before /// <remarks>Intentionally set one version behind to allow self hosted users some time to update before
/// getting out of date license errors /// getting out of date license errors
/// </remarks> /// </remarks>
public const int CurrentLicenseFileVersion = 14; public const int CurrentLicenseFileVersion = 15;
private bool ValidLicenseVersion private bool ValidLicenseVersion
{ {
get => Version is >= 1 and <= 15; get => Version is >= 1 and <= 16;
} }
public byte[] GetDataBytes(bool forHash = false) public byte[] GetDataBytes(bool forHash = false)
@ -244,6 +245,8 @@ public class OrganizationLicense : ILicense
(Version >= 14 || !p.Name.Equals(nameof(LimitCollectionCreationDeletion))) && (Version >= 14 || !p.Name.Equals(nameof(LimitCollectionCreationDeletion))) &&
// AllowAdminAccessToAllCollectionItems was added in Version 15 // AllowAdminAccessToAllCollectionItems was added in Version 15
(Version >= 15 || !p.Name.Equals(nameof(AllowAdminAccessToAllCollectionItems))) && (Version >= 15 || !p.Name.Equals(nameof(AllowAdminAccessToAllCollectionItems))) &&
// UseOrganizationDomains was added in Version 16
(Version >= 16 || !p.Name.Equals(nameof(UseOrganizationDomains))) &&
( (
!forHash || !forHash ||
( (
@ -252,7 +255,10 @@ public class OrganizationLicense : ILicense
!p.Name.Equals(nameof(Refresh)) !p.Name.Equals(nameof(Refresh))
) )
) && ) &&
!p.Name.Equals(nameof(UseRiskInsights))) // any new fields added need to be added here so that they're ignored
!p.Name.Equals(nameof(UseRiskInsights)) &&
!p.Name.Equals(nameof(UseAdminSponsoredFamilies)) &&
!p.Name.Equals(nameof(UseOrganizationDomains)))
.OrderBy(p => p.Name) .OrderBy(p => p.Name)
.Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}")
.Aggregate((c, n) => $"{c}|{n}"); .Aggregate((c, n) => $"{c}|{n}");
@ -583,6 +589,11 @@ public class OrganizationLicense : ILicense
* validation. * validation.
*/ */
if (valid && Version >= 16)
{
valid = organization.UseOrganizationDomains;
}
return valid; return valid;
} }

View File

@ -69,8 +69,11 @@ public static class OrganizationServiceCollectionExtensions
services.AddBaseOrganizationSubscriptionCommandsQueries(); services.AddBaseOrganizationSubscriptionCommandsQueries();
} }
private static IServiceCollection AddOrganizationSignUpCommands(this IServiceCollection services) => private static void AddOrganizationSignUpCommands(this IServiceCollection services)
{
services.AddScoped<ICloudOrganizationSignUpCommand, CloudOrganizationSignUpCommand>(); services.AddScoped<ICloudOrganizationSignUpCommand, CloudOrganizationSignUpCommand>();
services.AddScoped<IProviderClientOrganizationSignUpCommand, ProviderClientOrganizationSignUpCommand>();
}
private static void AddOrganizationDeleteCommands(this IServiceCollection services) private static void AddOrganizationDeleteCommands(this IServiceCollection services)
{ {

View File

@ -4,7 +4,6 @@ using Bit.Core.Billing.Models;
using Bit.Core.Billing.Tax.Requests; using Bit.Core.Billing.Tax.Requests;
using Bit.Core.Billing.Tax.Responses; using Bit.Core.Billing.Tax.Responses;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
using Bit.Core.Models.StaticStore; using Bit.Core.Models.StaticStore;
@ -30,8 +29,6 @@ public interface IPaymentService
Task<string> AdjustServiceAccountsAsync(Organization organization, Plan plan, int additionalServiceAccounts); Task<string> AdjustServiceAccountsAsync(Organization organization, Plan plan, int additionalServiceAccounts);
Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false); Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false);
Task ReinstateSubscriptionAsync(ISubscriber subscriber); Task ReinstateSubscriptionAsync(ISubscriber subscriber);
Task<bool> UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType,
string paymentToken, TaxInfo taxInfo = null);
Task<bool> CreditAccountAsync(ISubscriber subscriber, decimal creditAmount); Task<bool> CreditAccountAsync(ISubscriber subscriber, decimal creditAmount);
Task<BillingInfo> GetBillingAsync(ISubscriber subscriber); Task<BillingInfo> GetBillingAsync(ISubscriber subscriber);
Task<BillingHistoryInfo> GetBillingHistoryAsync(ISubscriber subscriber); Task<BillingHistoryInfo> GetBillingHistoryAsync(ISubscriber subscriber);

View File

@ -1,13 +1,13 @@
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Requests; using Bit.Core.Billing.Tax.Requests;
using Bit.Core.Billing.Tax.Responses; using Bit.Core.Billing.Tax.Responses;
using Bit.Core.Billing.Tax.Services; using Bit.Core.Billing.Tax.Services;
@ -38,7 +38,6 @@ public class StripePaymentService : IPaymentService
private readonly IGlobalSettings _globalSettings; private readonly IGlobalSettings _globalSettings;
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
private readonly ITaxService _taxService; private readonly ITaxService _taxService;
private readonly ISubscriberService _subscriberService;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly IAutomaticTaxFactory _automaticTaxFactory; private readonly IAutomaticTaxFactory _automaticTaxFactory;
private readonly IAutomaticTaxStrategy _personalUseTaxStrategy; private readonly IAutomaticTaxStrategy _personalUseTaxStrategy;
@ -51,7 +50,6 @@ public class StripePaymentService : IPaymentService
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
IFeatureService featureService, IFeatureService featureService,
ITaxService taxService, ITaxService taxService,
ISubscriberService subscriberService,
IPricingClient pricingClient, IPricingClient pricingClient,
IAutomaticTaxFactory automaticTaxFactory, IAutomaticTaxFactory automaticTaxFactory,
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy) [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy)
@ -63,7 +61,6 @@ public class StripePaymentService : IPaymentService
_globalSettings = globalSettings; _globalSettings = globalSettings;
_featureService = featureService; _featureService = featureService;
_taxService = taxService; _taxService = taxService;
_subscriberService = subscriberService;
_pricingClient = pricingClient; _pricingClient = pricingClient;
_automaticTaxFactory = automaticTaxFactory; _automaticTaxFactory = automaticTaxFactory;
_personalUseTaxStrategy = personalUseTaxStrategy; _personalUseTaxStrategy = personalUseTaxStrategy;
@ -136,15 +133,68 @@ public class StripePaymentService : IPaymentService
if (subscriptionUpdate is CompleteSubscriptionUpdate) if (subscriptionUpdate is CompleteSubscriptionUpdate)
{ {
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) var setNonUSBusinessUseToReverseCharge =
_featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{ {
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, updatedItemOptions.Select(x => x.Plan ?? x.Price)); if (sub.Customer is
var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); {
automaticTaxStrategy.SetUpdateOptions(subUpdateOptions, sub); Address.Country: not "US",
TaxExempt: not StripeConstants.TaxExempt.Reverse
})
{
await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
}
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
} }
else else if (sub.Customer.HasRecognizedTaxLocation())
{ {
subUpdateOptions.EnableAutomaticTax(sub.Customer, sub); switch (subscriber)
{
case User:
{
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
break;
}
case Organization:
{
if (sub.Customer.Address.Country == "US")
{
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
}
else
{
var familyPriceIds = (await Task.WhenAll(
_pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019),
_pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)))
.Select(plan => plan.PasswordManager.StripePlanId);
var updateIsForPersonalUse = updatedItemOptions
.Select(option => option.Price)
.Intersect(familyPriceIds)
.Any();
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = updateIsForPersonalUse || sub.Customer.TaxIds.Any()
};
}
break;
}
case Provider:
{
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = sub.Customer.Address.Country == "US" ||
sub.Customer.TaxIds.Any()
};
break;
}
}
} }
} }
@ -202,7 +252,7 @@ public class StripePaymentService : IPaymentService
} }
else if (!invoice.Paid) else if (!invoice.Paid)
{ {
// Pay invoice with no charge to customer this completes the invoice immediately without waiting the scheduled 1h // Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h
invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId);
paymentIntentClientSecret = null; paymentIntentClientSecret = null;
} }
@ -585,309 +635,6 @@ public class StripePaymentService : IPaymentService
} }
} }
public async Task<bool> UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType,
string paymentToken, TaxInfo taxInfo = null)
{
if (subscriber == null)
{
throw new ArgumentNullException(nameof(subscriber));
}
if (subscriber.Gateway.HasValue && subscriber.Gateway.Value != GatewayType.Stripe)
{
throw new GatewayException("Switching from one payment type to another is not supported. " +
"Contact us for assistance.");
}
var createdCustomer = false;
Braintree.Customer braintreeCustomer = null;
string stipeCustomerSourceToken = null;
string stipeCustomerPaymentMethodId = null;
var stripeCustomerMetadata = new Dictionary<string, string>
{
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
};
var stripePaymentMethod = paymentMethodType is PaymentMethodType.Card or PaymentMethodType.BankAccount;
Customer customer = null;
if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
{
var options = new CustomerGetOptions { Expand = ["sources", "tax", "subscriptions"] };
customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, options);
if (customer.Metadata?.Any() ?? false)
{
stripeCustomerMetadata = customer.Metadata;
}
}
var hadBtCustomer = stripeCustomerMetadata.ContainsKey("btCustomerId");
if (stripePaymentMethod)
{
if (paymentToken.StartsWith("pm_"))
{
stipeCustomerPaymentMethodId = paymentToken;
}
else
{
stipeCustomerSourceToken = paymentToken;
}
}
else if (paymentMethodType == PaymentMethodType.PayPal)
{
if (hadBtCustomer)
{
var pmResult = await _btGateway.PaymentMethod.CreateAsync(new Braintree.PaymentMethodRequest
{
CustomerId = stripeCustomerMetadata["btCustomerId"],
PaymentMethodNonce = paymentToken
});
if (pmResult.IsSuccess())
{
var customerResult = await _btGateway.Customer.UpdateAsync(
stripeCustomerMetadata["btCustomerId"], new Braintree.CustomerRequest
{
DefaultPaymentMethodToken = pmResult.Target.Token
});
if (customerResult.IsSuccess() && customerResult.Target.PaymentMethods.Length > 0)
{
braintreeCustomer = customerResult.Target;
}
else
{
await _btGateway.PaymentMethod.DeleteAsync(pmResult.Target.Token);
hadBtCustomer = false;
}
}
else
{
hadBtCustomer = false;
}
}
if (!hadBtCustomer)
{
var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest
{
PaymentMethodNonce = paymentToken,
Email = subscriber.BillingEmailAddress(),
Id = subscriber.BraintreeCustomerIdPrefix() + subscriber.Id.ToString("N").ToLower() +
Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false),
CustomFields = new Dictionary<string, string>
{
[subscriber.BraintreeIdField()] = subscriber.Id.ToString(),
[subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion
}
});
if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0)
{
throw new GatewayException("Failed to create PayPal customer record.");
}
braintreeCustomer = customerResult.Target;
}
}
else
{
throw new GatewayException("Payment method is not supported at this time.");
}
if (stripeCustomerMetadata.ContainsKey("btCustomerId"))
{
if (braintreeCustomer?.Id != stripeCustomerMetadata["btCustomerId"])
{
stripeCustomerMetadata["btCustomerId_old"] = stripeCustomerMetadata["btCustomerId"];
}
stripeCustomerMetadata["btCustomerId"] = braintreeCustomer?.Id;
}
else if (!string.IsNullOrWhiteSpace(braintreeCustomer?.Id))
{
stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id);
}
try
{
if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber))
{
taxInfo.TaxIdType = taxInfo.TaxIdType ??
_taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber);
}
if (customer == null)
{
customer = await _stripeAdapter.CustomerCreateAsync(new CustomerCreateOptions
{
Description = subscriber.BillingName(),
Email = subscriber.BillingEmailAddress(),
Metadata = stripeCustomerMetadata,
Source = stipeCustomerSourceToken,
PaymentMethod = stipeCustomerPaymentMethodId,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
DefaultPaymentMethod = stipeCustomerPaymentMethodId,
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions()
{
Name = subscriber.SubscriberType(),
Value = subscriber.GetFormattedInvoiceName()
}
]
},
Address = taxInfo == null ? null : new AddressOptions
{
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode,
Line1 = taxInfo.BillingAddressLine1 ?? string.Empty,
Line2 = taxInfo.BillingAddressLine2,
City = taxInfo.BillingAddressCity,
State = taxInfo.BillingAddressState
},
TaxIdData = string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)
? []
: [
new CustomerTaxIdDataOptions
{
Type = taxInfo.TaxIdType,
Value = taxInfo.TaxIdNumber
}
],
Expand = ["sources", "tax", "subscriptions"],
});
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = customer.Id;
createdCustomer = true;
}
if (!createdCustomer)
{
string defaultSourceId = null;
string defaultPaymentMethodId = null;
if (stripePaymentMethod)
{
if (!string.IsNullOrWhiteSpace(stipeCustomerSourceToken) && paymentToken.StartsWith("btok_"))
{
var bankAccount = await _stripeAdapter.BankAccountCreateAsync(customer.Id, new BankAccountCreateOptions
{
Source = paymentToken
});
defaultSourceId = bankAccount.Id;
}
else if (!string.IsNullOrWhiteSpace(stipeCustomerPaymentMethodId))
{
await _stripeAdapter.PaymentMethodAttachAsync(stipeCustomerPaymentMethodId,
new PaymentMethodAttachOptions { Customer = customer.Id });
defaultPaymentMethodId = stipeCustomerPaymentMethodId;
}
}
if (customer.Sources != null)
{
foreach (var source in customer.Sources.Where(s => s.Id != defaultSourceId))
{
if (source is BankAccount)
{
await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id);
}
else if (source is Card)
{
await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id);
}
}
}
var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging(new PaymentMethodListOptions
{
Customer = customer.Id,
Type = "card"
});
foreach (var cardMethod in cardPaymentMethods.Where(m => m.Id != defaultPaymentMethodId))
{
await _stripeAdapter.PaymentMethodDetachAsync(cardMethod.Id, new PaymentMethodDetachOptions());
}
await _subscriberService.UpdateTaxInformation(subscriber, TaxInformation.From(taxInfo));
customer = await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Metadata = stripeCustomerMetadata,
DefaultSource = defaultSourceId,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
DefaultPaymentMethod = defaultPaymentMethodId,
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions()
{
Name = subscriber.SubscriberType(),
Value = subscriber.GetFormattedInvoiceName()
}
]
},
Expand = ["tax", "subscriptions"]
});
}
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
{
var subscriptionGetOptions = new SubscriptionGetOptions
{
Expand = ["customer.tax", "customer.tax_ids"]
};
var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id));
var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters);
var subscriptionUpdateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (subscriptionUpdateOptions != null)
{
_ = await _stripeAdapter.SubscriptionUpdateAsync(
subscriber.GatewaySubscriptionId,
subscriptionUpdateOptions);
}
}
}
else
{
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) &&
customer.Subscriptions.Any(sub =>
sub.Id == subscriber.GatewaySubscriptionId &&
!sub.AutomaticTax.Enabled) &&
customer.HasTaxLocationVerified())
{
var subscriptionUpdateOptions = new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
DefaultTaxRates = []
};
_ = await _stripeAdapter.SubscriptionUpdateAsync(
subscriber.GatewaySubscriptionId,
subscriptionUpdateOptions);
}
}
}
catch
{
if (braintreeCustomer != null && !hadBtCustomer)
{
await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id);
}
throw;
}
return createdCustomer;
}
public async Task<bool> CreditAccountAsync(ISubscriber subscriber, decimal creditAmount) public async Task<bool> CreditAccountAsync(ISubscriber subscriber, decimal creditAmount)
{ {
Customer customer = null; Customer customer = null;
@ -1018,7 +765,7 @@ public class StripePaymentService : IPaymentService
var address = customer.Address; var address = customer.Address;
var taxId = customer.TaxIds?.FirstOrDefault(); var taxId = customer.TaxIds?.FirstOrDefault();
// Line1 is required, so if missing we're using the subscriber name // Line1 is required, so if missing we're using the subscriber name,
// see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1
if (address != null && string.IsNullOrWhiteSpace(address.Line1)) if (address != null && string.IsNullOrWhiteSpace(address.Line1))
{ {

View File

@ -1341,9 +1341,7 @@ public class UserService : UserManager<User>, IUserService, IDisposable
var organizationsWithVerifiedUserEmailDomain = await _organizationRepository.GetByVerifiedUserEmailDomainAsync(userId); var organizationsWithVerifiedUserEmailDomain = await _organizationRepository.GetByVerifiedUserEmailDomainAsync(userId);
// Organizations must be enabled and able to have verified domains. // Organizations must be enabled and able to have verified domains.
// TODO: Replace "UseSso" with a new organization ability like "UseOrganizationDomains" (PM-11622). return organizationsWithVerifiedUserEmailDomain.Where(organization => organization is { Enabled: true, UseOrganizationDomains: true });
// Verified domains were tied to SSO, so we currently check the "UseSso" organization ability.
return organizationsWithVerifiedUserEmailDomain.Where(organization => organization is { Enabled: true, UseSso: true });
} }
/// <inheritdoc cref="IsLegacyUser(string)"/> /// <inheritdoc cref="IsLegacyUser(string)"/>

View File

@ -0,0 +1,19 @@
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.Models.Data;
/// <summary>
/// This enum represents the possible results when attempting to access a <see cref="Send"/>.
/// </summary>
/// <member>name="Granted">Access is granted for the <see cref="Send"/>.</member>
/// <member>name="PasswordRequired">Access is denied, but a password is required to access the <see cref="Send"/>.
/// </member>
/// <member>name="PasswordInvalid">Access is denied due to an invalid password.</member>
/// <member>name="Denied">Access is denied for the <see cref="Send"/>.</member>
public enum SendAccessResult
{
Granted,
PasswordRequired,
PasswordInvalid,
Denied
}

View File

@ -0,0 +1,52 @@
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services;
namespace Bit.Core.Tools.SendFeatures.Commands;
public class AnonymousSendCommand : IAnonymousSendCommand
{
private readonly ISendRepository _sendRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPushNotificationService _pushNotificationService;
private readonly ISendAuthorizationService _sendAuthorizationService;
public AnonymousSendCommand(
ISendRepository sendRepository,
ISendFileStorageService sendFileStorageService,
IPushNotificationService pushNotificationService,
ISendAuthorizationService sendAuthorizationService
)
{
_sendRepository = sendRepository;
_sendFileStorageService = sendFileStorageService;
_pushNotificationService = pushNotificationService;
_sendAuthorizationService = sendAuthorizationService;
}
// Response: Send, password required, password invalid
public async Task<(string, SendAccessResult)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Can only get a download URL for a file type of Send");
}
var result = _sendAuthorizationService.SendCanBeAccessed(send, password);
if (!result.Equals(SendAccessResult.Granted))
{
return (null, result);
}
send.AccessCount++;
await _sendRepository.ReplaceAsync(send);
await _pushNotificationService.PushSyncSendUpdateAsync(send);
return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), result);
}
}

View File

@ -0,0 +1,21 @@
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
namespace Bit.Core.Tools.SendFeatures.Commands.Interfaces;
/// <summary>
/// AnonymousSendCommand interface provides methods for managing anonymous Sends.
/// </summary>
public interface IAnonymousSendCommand
{
/// <summary>
/// Gets the Send file download URL for a Send object.
/// </summary>
/// <param name="send"><see cref="Send" /> used to help get file download url and validate file</param>
/// <param name="fileId">FileId get file download url</param>
/// <param name="password">A hashed and base64-encoded password. This is compared with the send's password to authorize access.</param>
/// <returns>Async Task object with Tuple containing the string of download url and <see cref="SendAccessResult" />
/// to determine if the user can access send.
/// </returns>
Task<(string, SendAccessResult)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password);
}

View File

@ -0,0 +1,53 @@
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
namespace Bit.Core.Tools.SendFeatures.Commands.Interfaces;
/// <summary>
/// NonAnonymousSendCommand interface provides methods for managing non-anonymous Sends.
/// </summary>
public interface INonAnonymousSendCommand
{
/// <summary>
/// Saves a <see cref="Send" /> to the database.
/// </summary>
/// <param name="send"><see cref="Send" /> that will save to database</param>
/// <returns>Task completes as <see cref="Send" /> saves to the database</returns>
Task SaveSendAsync(Send send);
/// <summary>
/// Saves the <see cref="Send" /> and <see cref="SendFileData" /> to the database.
/// </summary>
/// <param name="send"><see cref="Send" /> that will save to the database</param>
/// <param name="data"><see cref="SendFileData" /> that will save to file storage</param>
/// <param name="fileLength">Length of file help with saving to file storage</param>
/// <returns>Task object for async operations with file upload url</returns>
Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength);
/// <summary>
/// Upload a file to an existing <see cref="Send" />.
/// </summary>
/// <param name="stream"><see cref="Stream" /> of file to be uploaded. The <see cref="Stream" /> position
/// will be set to 0 before uploading the file.</param>
/// <param name="send"><see cref="Send" /> used to help with uploading file</param>
/// <returns>Task completes after saving <see cref="Stream" /> and <see cref="Send" /> metadata to the file storage</returns>
Task UploadFileToExistingSendAsync(Stream stream, Send send);
/// <summary>
/// Deletes a <see cref="Send" /> from the database and file storage.
/// </summary>
/// <param name="send"><see cref="Send" /> is used to delete from database and file storage</param>
/// <returns>Task completes once <see cref="Send" /> has been deleted from database and file storage.</returns>
Task DeleteSendAsync(Send send);
/// <summary>
/// Stores the confirmed file size of a send; when the file size cannot be confirmed, the send is deleted.
/// </summary>
/// <param name="send">The <see cref="Send" /> this command acts upon</param>
/// <returns><see langword="true" /> when the file is confirmed, otherwise <see langword="false" /></returns>
/// <remarks>
/// When a file size cannot be confirmed, we assume we're working with a rogue client. The send is deleted out of
/// an abundance of caution.
/// </remarks>
Task<bool> ConfirmFileSize(Send send);
}

View File

@ -0,0 +1,180 @@
using System.Text.Json;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
namespace Bit.Core.Tools.SendFeatures.Commands;
public class NonAnonymousSendCommand : INonAnonymousSendCommand
{
private readonly ISendRepository _sendRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPushNotificationService _pushNotificationService;
private readonly ISendValidationService _sendValidationService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
private readonly ISendCoreHelperService _sendCoreHelperService;
public NonAnonymousSendCommand(ISendRepository sendRepository,
ISendFileStorageService sendFileStorageService,
IPushNotificationService pushNotificationService,
ISendAuthorizationService sendAuthorizationService,
ISendValidationService sendValidationService,
IReferenceEventService referenceEventService,
ICurrentContext currentContext,
ISendCoreHelperService sendCoreHelperService)
{
_sendRepository = sendRepository;
_sendFileStorageService = sendFileStorageService;
_pushNotificationService = pushNotificationService;
_sendValidationService = sendValidationService;
_referenceEventService = referenceEventService;
_currentContext = currentContext;
_sendCoreHelperService = sendCoreHelperService;
}
public async Task SaveSendAsync(Send send)
{
// Make sure user can save Sends
await _sendValidationService.ValidateUserCanSaveAsync(send.UserId, send);
if (send.Id == default(Guid))
{
await _sendRepository.CreateAsync(send);
await _pushNotificationService.PushSyncSendCreateAsync(send);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
{
Id = send.UserId ?? default,
Type = ReferenceEventType.SendCreated,
Source = ReferenceEventSource.User,
SendType = send.Type,
MaxAccessCount = send.MaxAccessCount,
HasPassword = !string.IsNullOrWhiteSpace(send.Password),
SendHasNotes = send.Data?.Contains("Notes"),
ClientId = _currentContext.ClientId,
ClientVersion = _currentContext.ClientVersion
});
}
else
{
send.RevisionDate = DateTime.UtcNow;
await _sendRepository.UpsertAsync(send);
await _pushNotificationService.PushSyncSendUpdateAsync(send);
}
}
public async Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Send is not of type \"file\".");
}
if (fileLength < 1)
{
throw new BadRequestException("No file data.");
}
var storageBytesRemaining = await _sendValidationService.StorageRemainingForSendAsync(send);
if (storageBytesRemaining < fileLength)
{
throw new BadRequestException("Not enough storage available.");
}
var fileId = _sendCoreHelperService.SecureRandomString(32, useUpperCase: false, useSpecial: false);
try
{
data.Id = fileId;
data.Size = fileLength;
data.Validated = false;
send.Data = JsonSerializer.Serialize(data,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId);
}
catch
{
// Clean up since this is not transactional
await _sendFileStorageService.DeleteFileAsync(send, fileId);
throw;
}
}
public async Task UploadFileToExistingSendAsync(Stream stream, Send send)
{
if (stream.Position > 0)
{
stream.Position = 0;
}
if (send?.Data == null)
{
throw new BadRequestException("Send does not have file data");
}
if (send.Type != SendType.File)
{
throw new BadRequestException("Not a File Type Send.");
}
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
if (data.Validated)
{
throw new BadRequestException("File has already been uploaded.");
}
await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id);
if (!await ConfirmFileSize(send))
{
throw new BadRequestException("File received does not match expected file length.");
}
}
public async Task DeleteSendAsync(Send send)
{
await _sendRepository.DeleteAsync(send);
if (send.Type == Enums.SendType.File)
{
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
await _sendFileStorageService.DeleteFileAsync(send, data.Id);
}
await _pushNotificationService.PushSyncSendDeleteAsync(send);
}
public async Task<bool> ConfirmFileSize(Send send)
{
var fileData = JsonSerializer.Deserialize<SendFileData>(send.Data);
var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, SendFileSettingHelper.FILE_SIZE_LEEWAY);
if (!valid || realSize > SendFileSettingHelper.FILE_SIZE_LEEWAY)
{
// File reported differs in size from that promised. Must be a rogue client. Delete Send
await DeleteSendAsync(send);
return false;
}
// Update Send data if necessary
if (realSize != fileData.Size)
{
fileData.Size = realSize.Value;
}
fileData.Validated = true;
send.Data = JsonSerializer.Serialize(fileData,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return valid;
}
}

View File

@ -0,0 +1,18 @@
using Bit.Core.Tools.SendFeatures.Commands;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Tools.SendFeatures;
public static class SendServiceCollectionExtension
{
public static void AddSendServices(this IServiceCollection services)
{
services.AddScoped<INonAnonymousSendCommand, NonAnonymousSendCommand>();
services.AddScoped<IAnonymousSendCommand, AnonymousSendCommand>();
services.AddScoped<ISendAuthorizationService, SendAuthorizationService>();
services.AddScoped<ISendValidationService, SendValidationService>();
services.AddScoped<ISendCoreHelperService, SendCoreHelperService>();
}
}

View File

@ -0,0 +1,28 @@
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
namespace Bit.Core.Tools.Services;
/// <summary>
/// Send Authorization service is responsible for checking if a Send can be accessed.
/// </summary>
public interface ISendAuthorizationService
{
/// <summary>
/// Checks if a <see cref="Send" /> can be accessed while updating the <see cref="Send" />, pushing a notification, and sending a reference event.
/// </summary>
/// <param name="send"><see cref="Send" /> used to determine access</param>
/// <param name="password">A hashed and base64-encoded password. This is compared with the send's password to authorize access.</param>
/// <returns><see cref="SendAccessResult" /> will be returned to determine if the user can access send.
/// </returns>
Task<SendAccessResult> AccessAsync(Send send, string password);
SendAccessResult SendCanBeAccessed(Send send,
string password);
/// <summary>
/// Hashes the password using the password hasher.
/// </summary>
/// <param name="password">Password to be hashed</param>
/// <returns>Hashed password of the password given</returns>
string HashPassword(string password);
}

View File

@ -0,0 +1,17 @@
namespace Bit.Core.Tools.Services;
/// <summary>
/// This interface provides helper methods for generating secure random strings. Making
/// it easier to mock the service in unit tests.
/// </summary>
public interface ISendCoreHelperService
{
/// <summary>
/// Securely generates a random string of the specified length.
/// </summary>
/// <param name="length">Desired string length to be returned</param>
/// <param name="useUpperCase">Desired casing for the string</param>
/// <param name="useSpecial">Determines if special characters will be used in string</param>
/// <returns>A secure random string with the desired parameters</returns>
string SecureRandomString(int length, bool useUpperCase, bool useSpecial);
}

View File

@ -0,0 +1,71 @@
using Bit.Core.Enums;
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.Services;
/// <summary>
/// Send File Storage Service is responsible for uploading, deleting, and validating files
/// whether they are in local storage or in cloud storage.
/// </summary>
public interface ISendFileStorageService
{
FileUploadType FileUploadType { get; }
/// <summary>
/// Uploads a new file to the storage.
/// </summary>
/// <param name="stream"><see cref="Stream" /> of the file</param>
/// <param name="send"><see cref="Send" /> for the file</param>
/// <param name="fileId">File id</param>
/// <returns>Task completes once <see cref="Stream" /> and <see cref="Send" /> have been saved to the database</returns>
Task UploadNewFileAsync(Stream stream, Send send, string fileId);
/// <summary>
/// Deletes a file from the storage.
/// </summary>
/// <param name="send"><see cref="Send" /> used to delete file</param>
/// <param name="fileId">File id of file to be deleted</param>
/// <returns>Task completes once <see cref="Send" /> has been deleted to the database</returns>
Task DeleteFileAsync(Send send, string fileId);
/// <summary>
/// Deletes all files for a specific organization.
/// </summary>
/// <param name="organizationId"><see cref="Guid" /> used to delete all files pertaining to organization</param>
/// <returns>Task completes after running code to delete files by organization id</returns>
Task DeleteFilesForOrganizationAsync(Guid organizationId);
/// <summary>
/// Deletes all files for a specific user.
/// </summary>
/// <param name="userId"><see cref="Guid" /> used to delete all files pertaining to user</param>
/// <returns>Task completes after running code to delete files by user id</returns>
Task DeleteFilesForUserAsync(Guid userId);
/// <summary>
/// Gets the download URL for a file.
/// </summary>
/// <param name="send"><see cref="Send" /> used to help get download url for file</param>
/// <param name="fileId">File id to help get download url for file</param>
/// <returns>Download url as a string</returns>
Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId);
/// <summary>
/// Gets the upload URL for a file.
/// </summary>
/// <param name="send"><see cref="Send" /> used to help get upload url for file </param>
/// <param name="fileId">File id to help get upload url for file</param>
/// <returns>File upload url as string</returns>
Task<string> GetSendFileUploadUrlAsync(Send send, string fileId);
/// <summary>
/// Validates the file size of a file in the storage.
/// </summary>
/// <param name="send"><see cref="Send" /> used to help validate file</param>
/// <param name="fileId">File id to identify which file to validate</param>
/// <param name="expectedFileSize">Expected file size of the file</param>
/// <param name="leeway">
/// Send file size tolerance in bytes. When an uploaded file's `expectedFileSize`
/// is outside of the leeway, the storage operation fails.
/// </param>
/// <throws>
/// ❌ Fill this in with an explanation of the error thrown when `leeway` is incorrect
/// </throws>
/// <returns>Task object for async operations with Tuple of boolean that determines if file was valid and long that
/// the actual file size of the file.
/// </returns>
Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway);
}

View File

@ -0,0 +1,35 @@
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.Services;
public interface ISendValidationService
{
/// <summary>
/// Validates a file can be saved by specified user.
/// </summary>
/// <param name="userId"><see cref="Guid" /> needed to validate file for specific user</param>
/// <param name="send"><see cref="Send" /> needed to help validate file</param>
/// <returns>Task completes when a conditional statement has been met it will return out of the method or
/// throw a BadRequestException.
/// </returns>
Task ValidateUserCanSaveAsync(Guid? userId, Send send);
/// <summary>
/// Validates a file can be saved by specified user with different policy based on feature flag
/// </summary>
/// <param name="userId"><see cref="Guid" /> needed to validate file for specific user</param>
/// <param name="send"><see cref="Send" /> needed to help validate file</param>
/// <returns>Task completes when a conditional statement has been met it will return out of the method or
/// throw a BadRequestException.
/// </returns>
Task ValidateUserCanSaveAsync_vNext(Guid? userId, Send send);
/// <summary>
/// Calculates the remaining storage for a Send.
/// </summary>
/// <param name="send"><see cref="Send" /> needed to help calculate remaining storage</param>
/// <returns>Long with the remaining bytes for storage or will throw a BadRequestException if user cannot access
/// file or email is not verified.
/// </returns>
Task<long> StorageRemainingForSendAsync(Send send);
}

View File

@ -0,0 +1,101 @@
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Tools.Services;
public class SendAuthorizationService : ISendAuthorizationService
{
private readonly ISendRepository _sendRepository;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IPushNotificationService _pushNotificationService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
public SendAuthorizationService(
ISendRepository sendRepository,
IPasswordHasher<User> passwordHasher,
IPushNotificationService pushNotificationService,
IReferenceEventService referenceEventService,
ICurrentContext currentContext)
{
_sendRepository = sendRepository;
_passwordHasher = passwordHasher;
_pushNotificationService = pushNotificationService;
_referenceEventService = referenceEventService;
_currentContext = currentContext;
}
public SendAccessResult SendCanBeAccessed(Send send,
string password)
{
var now = DateTime.UtcNow;
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
send.DeletionDate < now)
{
return SendAccessResult.Denied;
}
if (!string.IsNullOrWhiteSpace(send.Password))
{
if (string.IsNullOrWhiteSpace(password))
{
return SendAccessResult.PasswordRequired;
}
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
{
send.Password = HashPassword(password);
}
if (passwordResult == PasswordVerificationResult.Failed)
{
return SendAccessResult.PasswordInvalid;
}
}
return SendAccessResult.Granted;
}
public async Task<SendAccessResult> AccessAsync(Send sendToBeAccessed, string password)
{
var accessResult = SendCanBeAccessed(sendToBeAccessed, password);
if (!accessResult.Equals(SendAccessResult.Granted))
{
return accessResult;
}
if (sendToBeAccessed.Type != SendType.File)
{
// File sends are incremented during file download
sendToBeAccessed.AccessCount++;
}
await _sendRepository.ReplaceAsync(sendToBeAccessed);
await _pushNotificationService.PushSyncSendUpdateAsync(sendToBeAccessed);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
{
Id = sendToBeAccessed.UserId ?? default,
Type = ReferenceEventType.SendAccessed,
Source = ReferenceEventSource.User,
SendType = sendToBeAccessed.Type,
MaxAccessCount = sendToBeAccessed.MaxAccessCount,
HasPassword = !string.IsNullOrWhiteSpace(sendToBeAccessed.Password),
SendHasNotes = sendToBeAccessed.Data?.Contains("Notes"),
ClientId = _currentContext.ClientId,
ClientVersion = _currentContext.ClientVersion
});
return accessResult;
}
public string HashPassword(string password)
{
return _passwordHasher.HashPassword(new User(), password);
}
}

View File

@ -0,0 +1,12 @@
using Bit.Core.Utilities;
namespace Bit.Core.Tools.Services;
public class SendCoreHelperService : ISendCoreHelperService
{
public string SecureRandomString(int length, bool useUpperCase, bool useSpecial)
{
return CoreHelpers.SecureRandomString(length, upper: useUpperCase, special: useSpecial);
}
}

View File

@ -0,0 +1,26 @@
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.SendFeatures;
/// <summary>
/// SendFileSettingHelper is a static class that provides constants and helper methods (if needed) for managing file
/// settings.
/// </summary>
public static class SendFileSettingHelper
{
/// <summary>
/// The leeway for the file size. This is the calculated 1 megabyte of cushion when doing comparisons of file sizes
/// within the system.
/// </summary>
public const long FILE_SIZE_LEEWAY = 1024L * 1024L; // 1MB
/// <summary>
/// The maximum file size for a file uploaded in a <see cref="Send" />. Units are calculated in bytes but
/// represent 501 megabytes. 1 megabyte is added for cushion to account for file size.
/// </summary>
public const long MAX_FILE_SIZE = Constants.FileSize501mb;
/// <summary>
/// String of the expected file size and to be used when needing to communicate the file size to the client/user.
/// </summary>
public const string MAX_FILE_SIZE_READABLE = "500 MB";
}

View File

@ -0,0 +1,142 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities;
using Bit.Core.Utilities;
namespace Bit.Core.Tools.Services;
public class SendValidationService : ISendValidationService
{
private readonly IUserRepository _userRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IPolicyService _policyService;
private readonly IFeatureService _featureService;
private readonly IUserService _userService;
private readonly GlobalSettings _globalSettings;
private readonly ICurrentContext _currentContext;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
public SendValidationService(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IPolicyService policyService,
IFeatureService featureService,
IUserService userService,
IPolicyRequirementQuery policyRequirementQuery,
GlobalSettings globalSettings,
ICurrentContext currentContext)
{
_userRepository = userRepository;
_organizationRepository = organizationRepository;
_policyService = policyService;
_featureService = featureService;
_userService = userService;
_policyRequirementQuery = policyRequirementQuery;
_globalSettings = globalSettings;
_currentContext = currentContext;
}
public async Task ValidateUserCanSaveAsync(Guid? userId, Send send)
{
if (_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements))
{
await ValidateUserCanSaveAsync_vNext(userId, send);
return;
}
if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true))
{
return;
}
var anyDisableSendPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(userId.Value,
PolicyType.DisableSend);
if (anyDisableSendPolicies)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
if (send.HideEmail.GetValueOrDefault())
{
var sendOptionsPolicies = await _policyService.GetPoliciesApplicableToUserAsync(userId.Value, PolicyType.SendOptions);
if (sendOptionsPolicies.Any(p => CoreHelpers.LoadClassFromJsonData<SendOptionsPolicyData>(p.PolicyData)?.DisableHideEmail ?? false))
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
}
public async Task ValidateUserCanSaveAsync_vNext(Guid? userId, Send send)
{
if (!userId.HasValue)
{
return;
}
var disableSendRequirement = await _policyRequirementQuery.GetAsync<DisableSendPolicyRequirement>(userId.Value);
if (disableSendRequirement.DisableSend)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
var sendOptionsRequirement = await _policyRequirementQuery.GetAsync<SendOptionsPolicyRequirement>(userId.Value);
if (sendOptionsRequirement.DisableHideEmail && send.HideEmail.GetValueOrDefault())
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
public async Task<long> StorageRemainingForSendAsync(Send send)
{
var storageBytesRemaining = 0L;
if (send.UserId.HasValue)
{
var user = await _userRepository.GetByIdAsync(send.UserId.Value);
if (!await _userService.CanAccessPremium(user))
{
throw new BadRequestException("You must have premium status to use file Sends.");
}
if (!user.EmailVerified)
{
throw new BadRequestException("You must confirm your email to use file Sends.");
}
if (user.Premium)
{
storageBytesRemaining = user.StorageBytesRemaining();
}
else
{
// Users that get access to file storage/premium from their organization get the default
// 1 GB max storage.
short limit = _globalSettings.SelfHosted ? (short)10240 : (short)1;
storageBytesRemaining = user.StorageBytesRemaining(limit);
}
}
else if (send.OrganizationId.HasValue)
{
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
if (!org.MaxStorageGb.HasValue)
{
throw new BadRequestException("This organization cannot use file sends.");
}
storageBytesRemaining = org.StorageBytesRemaining();
}
return storageBytesRemaining;
}
}

View File

@ -1,16 +0,0 @@
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
namespace Bit.Core.Tools.Services;
public interface ISendService
{
Task DeleteSendAsync(Send send);
Task SaveSendAsync(Send send);
Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength);
Task UploadFileToExistingSendAsync(Stream stream, Send send);
Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password);
string HashPassword(string password);
Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password);
Task<bool> ValidateSendFile(Send send);
}

View File

@ -1,16 +0,0 @@
using Bit.Core.Enums;
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.Services;
public interface ISendFileStorageService
{
FileUploadType FileUploadType { get; }
Task UploadNewFileAsync(Stream stream, Send send, string fileId);
Task DeleteFileAsync(Send send, string fileId);
Task DeleteFilesForOrganizationAsync(Guid organizationId);
Task DeleteFilesForUserAsync(Guid userId);
Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId);
Task<string> GetSendFileUploadUrlAsync(Send send, string fileId);
Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway);
}

View File

@ -1,383 +0,0 @@
using System.Text.Json;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Tools.Services;
public class SendService : ISendService
{
public const long MAX_FILE_SIZE = Constants.FileSize501mb;
public const string MAX_FILE_SIZE_READABLE = "500 MB";
private readonly ISendRepository _sendRepository;
private readonly IUserRepository _userRepository;
private readonly IPolicyService _policyService;
private readonly IUserService _userService;
private readonly IOrganizationRepository _organizationRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IPushNotificationService _pushService;
private readonly IReferenceEventService _referenceEventService;
private readonly GlobalSettings _globalSettings;
private readonly ICurrentContext _currentContext;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IFeatureService _featureService;
private const long _fileSizeLeeway = 1024L * 1024L; // 1MB
public SendService(
ISendRepository sendRepository,
IUserRepository userRepository,
IUserService userService,
IOrganizationRepository organizationRepository,
ISendFileStorageService sendFileStorageService,
IPasswordHasher<User> passwordHasher,
IPushNotificationService pushService,
IReferenceEventService referenceEventService,
GlobalSettings globalSettings,
IPolicyService policyService,
ICurrentContext currentContext,
IPolicyRequirementQuery policyRequirementQuery,
IFeatureService featureService)
{
_sendRepository = sendRepository;
_userRepository = userRepository;
_userService = userService;
_policyService = policyService;
_organizationRepository = organizationRepository;
_sendFileStorageService = sendFileStorageService;
_passwordHasher = passwordHasher;
_pushService = pushService;
_referenceEventService = referenceEventService;
_globalSettings = globalSettings;
_currentContext = currentContext;
_policyRequirementQuery = policyRequirementQuery;
_featureService = featureService;
}
public async Task SaveSendAsync(Send send)
{
// Make sure user can save Sends
await ValidateUserCanSaveAsync(send.UserId, send);
if (send.Id == default(Guid))
{
await _sendRepository.CreateAsync(send);
await _pushService.PushSyncSendCreateAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendCreated);
}
else
{
send.RevisionDate = DateTime.UtcNow;
await _sendRepository.UpsertAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
}
}
public async Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Send is not of type \"file\".");
}
if (fileLength < 1)
{
throw new BadRequestException("No file data.");
}
var storageBytesRemaining = await StorageRemainingForSendAsync(send);
if (storageBytesRemaining < fileLength)
{
throw new BadRequestException("Not enough storage available.");
}
var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false);
try
{
data.Id = fileId;
data.Size = fileLength;
data.Validated = false;
send.Data = JsonSerializer.Serialize(data,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId);
}
catch
{
// Clean up since this is not transactional
await _sendFileStorageService.DeleteFileAsync(send, fileId);
throw;
}
}
public async Task UploadFileToExistingSendAsync(Stream stream, Send send)
{
if (send?.Data == null)
{
throw new BadRequestException("Send does not have file data");
}
if (send.Type != SendType.File)
{
throw new BadRequestException("Not a File Type Send.");
}
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
if (data.Validated)
{
throw new BadRequestException("File has already been uploaded.");
}
await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id);
if (!await ValidateSendFile(send))
{
throw new BadRequestException("File received does not match expected file length.");
}
}
public async Task<bool> ValidateSendFile(Send send)
{
var fileData = JsonSerializer.Deserialize<SendFileData>(send.Data);
var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway);
if (!valid || realSize > MAX_FILE_SIZE)
{
// File reported differs in size from that promised. Must be a rogue client. Delete Send
await DeleteSendAsync(send);
return false;
}
// Update Send data if necessary
if (realSize != fileData.Size)
{
fileData.Size = realSize.Value;
}
fileData.Validated = true;
send.Data = JsonSerializer.Serialize(fileData,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return valid;
}
public async Task DeleteSendAsync(Send send)
{
await _sendRepository.DeleteAsync(send);
if (send.Type == Enums.SendType.File)
{
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
await _sendFileStorageService.DeleteFileAsync(send, data.Id);
}
await _pushService.PushSyncSendDeleteAsync(send);
}
public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send,
string password)
{
var now = DateTime.UtcNow;
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
send.DeletionDate < now)
{
return (false, false, false);
}
if (!string.IsNullOrWhiteSpace(send.Password))
{
if (string.IsNullOrWhiteSpace(password))
{
return (false, true, false);
}
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
{
send.Password = HashPassword(password);
}
if (passwordResult == PasswordVerificationResult.Failed)
{
return (false, false, true);
}
}
return (true, false, false);
}
// Response: Send, password required, password invalid
public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Can only get a download URL for a file type of Send");
}
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
send.AccessCount++;
await _sendRepository.ReplaceAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false);
}
// Response: Send, password required, password invalid
public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password)
{
var send = await _sendRepository.GetByIdAsync(sendId);
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
// TODO: maybe move this to a simple ++ sproc?
if (send.Type != SendType.File)
{
// File sends are incremented during file download
send.AccessCount++;
}
await _sendRepository.ReplaceAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed);
return (send, false, false);
}
private async Task RaiseReferenceEventAsync(Send send, ReferenceEventType eventType)
{
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
{
Id = send.UserId ?? default,
Type = eventType,
Source = ReferenceEventSource.User,
SendType = send.Type,
MaxAccessCount = send.MaxAccessCount,
HasPassword = !string.IsNullOrWhiteSpace(send.Password),
SendHasNotes = send.Data?.Contains("Notes"),
ClientId = _currentContext.ClientId,
ClientVersion = _currentContext.ClientVersion
});
}
public string HashPassword(string password)
{
return _passwordHasher.HashPassword(new User(), password);
}
private async Task ValidateUserCanSaveAsync(Guid? userId, Send send)
{
if (_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements))
{
await ValidateUserCanSaveAsync_vNext(userId, send);
return;
}
if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true))
{
return;
}
var anyDisableSendPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(userId.Value,
PolicyType.DisableSend);
if (anyDisableSendPolicies)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
if (send.HideEmail.GetValueOrDefault())
{
var sendOptionsPolicies = await _policyService.GetPoliciesApplicableToUserAsync(userId.Value, PolicyType.SendOptions);
if (sendOptionsPolicies.Any(p => CoreHelpers.LoadClassFromJsonData<SendOptionsPolicyData>(p.PolicyData)?.DisableHideEmail ?? false))
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
}
private async Task ValidateUserCanSaveAsync_vNext(Guid? userId, Send send)
{
if (!userId.HasValue)
{
return;
}
var disableSendRequirement = await _policyRequirementQuery.GetAsync<DisableSendPolicyRequirement>(userId.Value);
if (disableSendRequirement.DisableSend)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
var sendOptionsRequirement = await _policyRequirementQuery.GetAsync<SendOptionsPolicyRequirement>(userId.Value);
if (sendOptionsRequirement.DisableHideEmail && send.HideEmail.GetValueOrDefault())
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
private async Task<long> StorageRemainingForSendAsync(Send send)
{
var storageBytesRemaining = 0L;
if (send.UserId.HasValue)
{
var user = await _userRepository.GetByIdAsync(send.UserId.Value);
if (!await _userService.CanAccessPremium(user))
{
throw new BadRequestException("You must have premium status to use file Sends.");
}
if (!user.EmailVerified)
{
throw new BadRequestException("You must confirm your email to use file Sends.");
}
if (user.Premium)
{
storageBytesRemaining = user.StorageBytesRemaining();
}
else
{
// Users that get access to file storage/premium from their organization get the default
// 1 GB max storage.
storageBytesRemaining = user.StorageBytesRemaining(
_globalSettings.SelfHosted ? (short)10240 : (short)1);
}
}
else if (send.OrganizationId.HasValue)
{
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
if (!org.MaxStorageGb.HasValue)
{
throw new BadRequestException("This organization cannot use file sends.");
}
storageBytesRemaining = org.StorageBytesRemaining();
}
return storageBytesRemaining;
}
}

View File

@ -64,12 +64,6 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator<ResourceOwner
public async Task ValidateAsync(ResourceOwnerPasswordValidationContext context) public async Task ValidateAsync(ResourceOwnerPasswordValidationContext context)
{ {
if (!AuthEmailHeaderIsValid(context))
{
context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant,
"Auth-Email header invalid.");
return;
}
var user = await _userManager.FindByEmailAsync(context.UserName.ToLowerInvariant()); var user = await _userManager.FindByEmailAsync(context.UserName.ToLowerInvariant());
// We want to keep this device around incase the device is new for the user // We want to keep this device around incase the device is new for the user
@ -168,29 +162,4 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator<ResourceOwner
return context.Result.Subject; return context.Result.Subject;
} }
private bool AuthEmailHeaderIsValid(ResourceOwnerPasswordValidationContext context)
{
if (_currentContext.HttpContext.Request.Headers.TryGetValue("Auth-Email", out var authEmailHeader))
{
try
{
var authEmailDecoded = CoreHelpers.Base64UrlDecodeString(authEmailHeader);
if (authEmailDecoded != context.UserName)
{
return false;
}
}
catch (Exception e) when (e is InvalidOperationException || e is FormatException)
{
// Invalid B64 encoding
return false;
}
}
else
{
return false;
}
return true;
}
} }

View File

@ -22,6 +22,7 @@ public class UserDecryptionOptionsBuilder : IUserDecryptionOptionsBuilder
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly IDeviceRepository _deviceRepository; private readonly IDeviceRepository _deviceRepository;
private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly ILoginApprovingClientTypes _loginApprovingClientTypes;
private UserDecryptionOptions _options = new UserDecryptionOptions(); private UserDecryptionOptions _options = new UserDecryptionOptions();
private User? _user; private User? _user;
@ -31,12 +32,14 @@ public class UserDecryptionOptionsBuilder : IUserDecryptionOptionsBuilder
public UserDecryptionOptionsBuilder( public UserDecryptionOptionsBuilder(
ICurrentContext currentContext, ICurrentContext currentContext,
IDeviceRepository deviceRepository, IDeviceRepository deviceRepository,
IOrganizationUserRepository organizationUserRepository IOrganizationUserRepository organizationUserRepository,
ILoginApprovingClientTypes loginApprovingClientTypes
) )
{ {
_currentContext = currentContext; _currentContext = currentContext;
_deviceRepository = deviceRepository; _deviceRepository = deviceRepository;
_organizationUserRepository = organizationUserRepository; _organizationUserRepository = organizationUserRepository;
_loginApprovingClientTypes = loginApprovingClientTypes;
} }
public IUserDecryptionOptionsBuilder ForUser(User user) public IUserDecryptionOptionsBuilder ForUser(User user)
@ -119,8 +122,7 @@ public class UserDecryptionOptionsBuilder : IUserDecryptionOptionsBuilder
// Checks if the current user has any devices that are capable of approving login with device requests except for // Checks if the current user has any devices that are capable of approving login with device requests except for
// their current device. // their current device.
// NOTE: this doesn't check for if the users have configured the devices to be capable of approving requests as that is a client side setting. // NOTE: this doesn't check for if the users have configured the devices to be capable of approving requests as that is a client side setting.
hasLoginApprovingDevice = allDevices hasLoginApprovingDevice = allDevices.Any(d => d.Identifier != _device.Identifier && _loginApprovingClientTypes.TypesThatCanApprove.Contains(DeviceTypes.ToClientType(d.Type)));
.Any(d => d.Identifier != _device.Identifier && LoginApprovingClientTypes.TypesThatCanApprove.Contains(DeviceTypes.ToClientType(d.Type)));
} }
// Determine if user has manage reset password permission as post sso logic requires it for forcing users with this permission to set a MP // Determine if user has manage reset password permission as post sso logic requires it for forcing users with this permission to set a MP

View File

@ -6,5 +6,12 @@ public class RegisterFinishResponseModel : ResponseModel
{ {
public RegisterFinishResponseModel() public RegisterFinishResponseModel()
: base("registerFinish") : base("registerFinish")
{ } {
// We are setting this to an empty string so that old mobile clients don't break, as they reqiure a non-null value.
// This will be cleaned up in https://bitwarden.atlassian.net/browse/PM-21720.
CaptchaBypassToken = string.Empty;
}
public string CaptchaBypassToken { get; set; }
} }

View File

@ -1,22 +1,39 @@
using Bit.Core.Enums; using Bit.Core;
using Bit.Core.Enums;
using Bit.Core.Services;
namespace Bit.Identity.Utilities; namespace Bit.Identity.Utilities;
public static class LoginApprovingClientTypes public interface ILoginApprovingClientTypes
{ {
private static readonly IReadOnlyCollection<ClientType> _clientTypesThatCanApprove; IReadOnlyCollection<ClientType> TypesThatCanApprove { get; }
}
static LoginApprovingClientTypes() public class LoginApprovingClientTypes : ILoginApprovingClientTypes
{
public LoginApprovingClientTypes(
IFeatureService featureService)
{ {
var clientTypes = new List<ClientType> if (featureService.IsEnabled(FeatureFlagKeys.BrowserExtensionLoginApproval))
{ {
ClientType.Desktop, TypesThatCanApprove = new List<ClientType>
ClientType.Mobile, {
ClientType.Web, ClientType.Desktop,
ClientType.Browser, ClientType.Mobile,
}; ClientType.Web,
_clientTypesThatCanApprove = clientTypes.AsReadOnly(); ClientType.Browser,
};
}
else
{
TypesThatCanApprove = new List<ClientType>
{
ClientType.Desktop,
ClientType.Mobile,
ClientType.Web,
};
}
} }
public static IReadOnlyCollection<ClientType> TypesThatCanApprove => _clientTypesThatCanApprove; public IReadOnlyCollection<ClientType> TypesThatCanApprove { get; }
} }

View File

@ -23,6 +23,7 @@ public static class ServiceCollectionExtensions
services.AddTransient<IUserDecryptionOptionsBuilder, UserDecryptionOptionsBuilder>(); services.AddTransient<IUserDecryptionOptionsBuilder, UserDecryptionOptionsBuilder>();
services.AddTransient<IDeviceValidator, DeviceValidator>(); services.AddTransient<IDeviceValidator, DeviceValidator>();
services.AddTransient<ITwoFactorAuthenticationValidator, TwoFactorAuthenticationValidator>(); services.AddTransient<ITwoFactorAuthenticationValidator, TwoFactorAuthenticationValidator>();
services.AddTransient<ILoginApprovingClientTypes, LoginApprovingClientTypes>();
var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalIdentity); var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalIdentity);
var identityServerBuilder = services var identityServerBuilder = services

View File

@ -43,6 +43,7 @@ using Bit.Core.Settings;
using Bit.Core.Tokens; using Bit.Core.Tokens;
using Bit.Core.Tools.ImportFeatures; using Bit.Core.Tools.ImportFeatures;
using Bit.Core.Tools.ReportFeatures; using Bit.Core.Tools.ReportFeatures;
using Bit.Core.Tools.SendFeatures;
using Bit.Core.Tools.Services; using Bit.Core.Tools.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Bit.Core.Vault; using Bit.Core.Vault;
@ -123,7 +124,7 @@ public static class ServiceCollectionExtensions
services.AddScoped<ISsoConfigService, SsoConfigService>(); services.AddScoped<ISsoConfigService, SsoConfigService>();
services.AddScoped<IAuthRequestService, AuthRequestService>(); services.AddScoped<IAuthRequestService, AuthRequestService>();
services.AddScoped<IDuoUniversalTokenService, DuoUniversalTokenService>(); services.AddScoped<IDuoUniversalTokenService, DuoUniversalTokenService>();
services.AddScoped<ISendService, SendService>(); services.AddScoped<ISendAuthorizationService, SendAuthorizationService>();
services.AddLoginServices(); services.AddLoginServices();
services.AddScoped<IOrganizationDomainService, OrganizationDomainService>(); services.AddScoped<IOrganizationDomainService, OrganizationDomainService>();
services.AddVaultServices(); services.AddVaultServices();
@ -132,6 +133,7 @@ public static class ServiceCollectionExtensions
services.AddNotificationCenterServices(); services.AddNotificationCenterServices();
services.AddPlatformServices(); services.AddPlatformServices();
services.AddImportServices(); services.AddImportServices();
services.AddSendServices();
} }
public static void AddTokenizers(this IServiceCollection services) public static void AddTokenizers(this IServiceCollection services)

View File

@ -23,11 +23,11 @@ public class SendRotationValidatorTests
public async Task ValidateAsync_Success() public async Task ValidateAsync_Success()
{ {
// Arrange // Arrange
var sendService = Substitute.For<ISendService>(); var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
var sendRepository = Substitute.For<ISendRepository>(); var sendRepository = Substitute.For<ISendRepository>();
var sut = new SendRotationValidator( var sut = new SendRotationValidator(
sendService, sendAuthorizationService,
sendRepository sendRepository
); );
@ -52,11 +52,11 @@ public class SendRotationValidatorTests
public async Task ValidateAsync_SendNotReturnedFromRepository_NotIncludedInOutput() public async Task ValidateAsync_SendNotReturnedFromRepository_NotIncludedInOutput()
{ {
// Arrange // Arrange
var sendService = Substitute.For<ISendService>(); var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
var sendRepository = Substitute.For<ISendRepository>(); var sendRepository = Substitute.For<ISendRepository>();
var sut = new SendRotationValidator( var sut = new SendRotationValidator(
sendService, sendAuthorizationService,
sendRepository sendRepository
); );
@ -76,11 +76,11 @@ public class SendRotationValidatorTests
public async Task ValidateAsync_InputMissingUserSend_Throws() public async Task ValidateAsync_InputMissingUserSend_Throws()
{ {
// Arrange // Arrange
var sendService = Substitute.For<ISendService>(); var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
var sendRepository = Substitute.For<ISendRepository>(); var sendRepository = Substitute.For<ISendRepository>();
var sut = new SendRotationValidator( var sut = new SendRotationValidator(
sendService, sendAuthorizationService,
sendRepository sendRepository
); );

View File

@ -10,7 +10,9 @@ using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Tools.Entities; using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums; using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories; using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services; using Bit.Core.Tools.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
@ -26,7 +28,9 @@ public class SendsControllerTests : IDisposable
private readonly GlobalSettings _globalSettings; private readonly GlobalSettings _globalSettings;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly ISendRepository _sendRepository; private readonly ISendRepository _sendRepository;
private readonly ISendService _sendService; private readonly INonAnonymousSendCommand _nonAnonymousSendCommand;
private readonly IAnonymousSendCommand _anonymousSendCommand;
private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly ISendFileStorageService _sendFileStorageService; private readonly ISendFileStorageService _sendFileStorageService;
private readonly ILogger<SendsController> _logger; private readonly ILogger<SendsController> _logger;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
@ -35,7 +39,9 @@ public class SendsControllerTests : IDisposable
{ {
_userService = Substitute.For<IUserService>(); _userService = Substitute.For<IUserService>();
_sendRepository = Substitute.For<ISendRepository>(); _sendRepository = Substitute.For<ISendRepository>();
_sendService = Substitute.For<ISendService>(); _nonAnonymousSendCommand = Substitute.For<INonAnonymousSendCommand>();
_anonymousSendCommand = Substitute.For<IAnonymousSendCommand>();
_sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
_sendFileStorageService = Substitute.For<ISendFileStorageService>(); _sendFileStorageService = Substitute.For<ISendFileStorageService>();
_globalSettings = new GlobalSettings(); _globalSettings = new GlobalSettings();
_logger = Substitute.For<ILogger<SendsController>>(); _logger = Substitute.For<ILogger<SendsController>>();
@ -44,7 +50,9 @@ public class SendsControllerTests : IDisposable
_sut = new SendsController( _sut = new SendsController(
_sendRepository, _sendRepository,
_userService, _userService,
_sendService, _sendAuthorizationService,
_anonymousSendCommand,
_nonAnonymousSendCommand,
_sendFileStorageService, _sendFileStorageService,
_logger, _logger,
_globalSettings, _globalSettings,
@ -68,7 +76,8 @@ public class SendsControllerTests : IDisposable
send.Data = JsonSerializer.Serialize(new Dictionary<string, string>()); send.Data = JsonSerializer.Serialize(new Dictionary<string, string>());
send.HideEmail = true; send.HideEmail = true;
_sendService.AccessAsync(id, null).Returns((send, false, false)); _sendRepository.GetByIdAsync(Arg.Any<Guid>()).Returns(send);
_sendAuthorizationService.AccessAsync(send, null).Returns(SendAccessResult.Granted);
_userService.GetUserByIdAsync(Arg.Any<Guid>()).Returns(user); _userService.GetUserByIdAsync(Arg.Any<Guid>()).Returns(user);
var request = new SendAccessRequestModel(); var request = new SendAccessRequestModel();

View File

@ -34,11 +34,11 @@ public class SendRequestModelTests
Type = SendType.Text, Type = SendType.Text,
}; };
var sendService = Substitute.For<ISendService>(); var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
sendService.HashPassword(Arg.Any<string>()) sendAuthorizationService.HashPassword(Arg.Any<string>())
.Returns((info) => $"hashed_{(string)info[0]}"); .Returns((info) => $"hashed_{(string)info[0]}");
var send = sendRequest.ToSend(Guid.NewGuid(), sendService); var send = sendRequest.ToSend(Guid.NewGuid(), sendAuthorizationService);
Assert.Equal(deletionDate, send.DeletionDate); Assert.Equal(deletionDate, send.DeletionDate);
Assert.False(send.Disabled); Assert.False(send.Disabled);

View File

@ -25,13 +25,13 @@ public class GetOrganizationUsersClaimedStatusQueryTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetUsersOrganizationManagementStatusAsync_WithUseSsoEnabled_Success( public async Task GetUsersOrganizationManagementStatusAsync_WithUseOrganizationDomainsEnabled_Success(
Organization organization, Organization organization,
ICollection<OrganizationUser> usersWithClaimedDomain, ICollection<OrganizationUser> usersWithClaimedDomain,
SutProvider<GetOrganizationUsersClaimedStatusQuery> sutProvider) SutProvider<GetOrganizationUsersClaimedStatusQuery> sutProvider)
{ {
organization.Enabled = true; organization.Enabled = true;
organization.UseSso = true; organization.UseOrganizationDomains = true;
var userIdWithoutClaimedDomain = Guid.NewGuid(); var userIdWithoutClaimedDomain = Guid.NewGuid();
var userIdsToCheck = usersWithClaimedDomain.Select(u => u.Id).Concat(new List<Guid> { userIdWithoutClaimedDomain }).ToList(); var userIdsToCheck = usersWithClaimedDomain.Select(u => u.Id).Concat(new List<Guid> { userIdWithoutClaimedDomain }).ToList();
@ -51,13 +51,13 @@ public class GetOrganizationUsersClaimedStatusQueryTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetUsersOrganizationManagementStatusAsync_WithUseSsoDisabled_ReturnsAllFalse( public async Task GetUsersOrganizationManagementStatusAsync_WithUseOrganizationDomainsDisabled_ReturnsAllFalse(
Organization organization, Organization organization,
ICollection<OrganizationUser> usersWithClaimedDomain, ICollection<OrganizationUser> usersWithClaimedDomain,
SutProvider<GetOrganizationUsersClaimedStatusQuery> sutProvider) SutProvider<GetOrganizationUsersClaimedStatusQuery> sutProvider)
{ {
organization.Enabled = true; organization.Enabled = true;
organization.UseSso = false; organization.UseOrganizationDomains = false;
var userIdWithoutClaimedDomain = Guid.NewGuid(); var userIdWithoutClaimedDomain = Guid.NewGuid();
var userIdsToCheck = usersWithClaimedDomain.Select(u => u.Id).Concat(new List<Guid> { userIdWithoutClaimedDomain }).ToList(); var userIdsToCheck = usersWithClaimedDomain.Select(u => u.Id).Concat(new List<Guid> { userIdWithoutClaimedDomain }).ToList();

View File

@ -0,0 +1,169 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Models.Data;
using Bit.Core.Models.StaticStore;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Organizations.OrganizationSignUp;
[SutProviderCustomize]
public class ProviderClientOrganizationSignUpCommandTests
{
[Theory]
[BitAutoData(PlanType.TeamsAnnually)]
[BitAutoData(PlanType.TeamsMonthly)]
[BitAutoData(PlanType.EnterpriseAnnually)]
[BitAutoData(PlanType.EnterpriseMonthly)]
public async Task SignupClientAsync_ValidParameters_CreatesOrganizationSuccessfully(
PlanType planType,
OrganizationSignup signup,
string collectionName,
SutProvider<ProviderClientOrganizationSignUpCommand> sutProvider)
{
signup.Plan = planType;
signup.AdditionalSeats = 15;
signup.CollectionName = collectionName;
var plan = StaticStore.GetPlan(signup.Plan);
sutProvider.GetDependency<IPricingClient>()
.GetPlanOrThrow(signup.Plan)
.Returns(plan);
var result = await sutProvider.Sut.SignUpClientOrganizationAsync(signup);
Assert.NotNull(result.Organization);
Assert.NotNull(result.DefaultCollection);
Assert.Equal(collectionName, result.DefaultCollection.Name);
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.CreateAsync(
Arg.Is<Organization>(o =>
o.Name == signup.Name &&
o.BillingEmail == signup.BillingEmail &&
o.PlanType == plan.Type &&
o.Seats == signup.AdditionalSeats &&
o.MaxCollections == plan.PasswordManager.MaxCollections &&
o.UsePasswordManager == true &&
o.UseSecretsManager == false &&
o.Status == OrganizationStatusType.Created
)
);
await sutProvider.GetDependency<IReferenceEventService>()
.Received(1)
.RaiseEventAsync(Arg.Is<ReferenceEvent>(referenceEvent =>
referenceEvent.Type == ReferenceEventType.Signup &&
referenceEvent.PlanName == plan.Name &&
referenceEvent.PlanType == plan.Type &&
referenceEvent.Seats == result.Organization.Seats &&
referenceEvent.Storage == result.Organization.MaxStorageGb &&
referenceEvent.SignupInitiationPath == signup.InitiationPath
));
await sutProvider.GetDependency<ICollectionRepository>()
.Received(1)
.CreateAsync(
Arg.Is<Collection>(c =>
c.Name == collectionName &&
c.OrganizationId == result.Organization.Id
),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>()
);
await sutProvider.GetDependency<IOrganizationApiKeyRepository>()
.Received(1)
.CreateAsync(
Arg.Is<OrganizationApiKey>(k =>
k.OrganizationId == result.Organization.Id &&
k.Type == OrganizationApiKeyType.Default
)
);
await sutProvider.GetDependency<IApplicationCacheService>()
.Received(1)
.UpsertOrganizationAbilityAsync(Arg.Is<Organization>(o => o.Id == result.Organization.Id));
}
[Theory]
[BitAutoData]
public async Task SignupClientAsync_NullPlan_ThrowsBadRequestException(
OrganizationSignup signup,
SutProvider<ProviderClientOrganizationSignUpCommand> sutProvider)
{
sutProvider.GetDependency<IPricingClient>()
.GetPlanOrThrow(signup.Plan)
.Returns((Plan)null);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.SignUpClientOrganizationAsync(signup));
Assert.Contains(ProviderClientOrganizationSignUpCommand.PlanNullErrorMessage, exception.Message);
}
[Theory]
[BitAutoData]
public async Task SignupClientAsync_NegativeAdditionalSeats_ThrowsBadRequestException(
OrganizationSignup signup,
SutProvider<ProviderClientOrganizationSignUpCommand> sutProvider)
{
signup.Plan = PlanType.TeamsMonthly;
signup.AdditionalSeats = -5;
var plan = StaticStore.GetPlan(signup.Plan);
sutProvider.GetDependency<IPricingClient>()
.GetPlanOrThrow(signup.Plan)
.Returns(plan);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.SignUpClientOrganizationAsync(signup));
Assert.Contains(ProviderClientOrganizationSignUpCommand.AdditionalSeatsNegativeErrorMessage, exception.Message);
}
[Theory]
[BitAutoData(PlanType.TeamsAnnually)]
public async Task SignupClientAsync_WhenExceptionIsThrown_CleanupIsPerformed(
PlanType planType,
OrganizationSignup signup,
SutProvider<ProviderClientOrganizationSignUpCommand> sutProvider)
{
signup.Plan = planType;
var plan = StaticStore.GetPlan(signup.Plan);
sutProvider.GetDependency<IPricingClient>()
.GetPlanOrThrow(signup.Plan)
.Returns(plan);
sutProvider.GetDependency<IOrganizationApiKeyRepository>()
.When(x => x.CreateAsync(Arg.Any<OrganizationApiKey>()))
.Do(_ => throw new Exception());
var thrownException = await Assert.ThrowsAsync<Exception>(
() => sutProvider.Sut.SignUpClientOrganizationAsync(signup));
await sutProvider.GetDependency<IOrganizationRepository>()
.Received(1)
.DeleteAsync(Arg.Is<Organization>(o => o.Name == signup.Name));
await sutProvider.GetDependency<IApplicationCacheService>()
.Received(1)
.DeleteOrganizationAbilityAsync(Arg.Any<Guid>());
}
}

View File

@ -9,7 +9,6 @@ using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Models.Business; using Bit.Core.Models.Business;
@ -177,47 +176,6 @@ public class OrganizationServiceTests
referenceEvent.Users == expectedNewUsersCount)); referenceEvent.Users == expectedNewUsersCount));
} }
[Theory, BitAutoData]
public async Task SignupClientAsync_Succeeds(
OrganizationSignup signup,
SutProvider<OrganizationService> sutProvider)
{
signup.Plan = PlanType.TeamsMonthly;
var plan = StaticStore.GetPlan(signup.Plan);
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(signup.Plan).Returns(plan);
var (organization, _, _) = await sutProvider.Sut.SignupClientAsync(signup);
await sutProvider.GetDependency<IOrganizationRepository>().Received(1).CreateAsync(Arg.Is<Organization>(org =>
org.Id == organization.Id &&
org.Name == signup.Name &&
org.Plan == plan.Name &&
org.PlanType == plan.Type &&
org.UsePolicies == plan.HasPolicies &&
org.PublicKey == signup.PublicKey &&
org.PrivateKey == signup.PrivateKey &&
org.UseSecretsManager == false));
await sutProvider.GetDependency<IOrganizationApiKeyRepository>().Received(1)
.CreateAsync(Arg.Is<OrganizationApiKey>(orgApiKey =>
orgApiKey.OrganizationId == organization.Id));
await sutProvider.GetDependency<IApplicationCacheService>().Received(1)
.UpsertOrganizationAbilityAsync(organization);
await sutProvider.GetDependency<IOrganizationUserRepository>().DidNotReceiveWithAnyArgs().CreateAsync(default);
await sutProvider.GetDependency<ICollectionRepository>().Received(1)
.CreateAsync(Arg.Is<Collection>(c => c.Name == signup.CollectionName && c.OrganizationId == organization.Id), null, null);
await sutProvider.GetDependency<IReferenceEventService>().Received(1).RaiseEventAsync(Arg.Is<ReferenceEvent>(
re =>
re.Type == ReferenceEventType.Signup &&
re.PlanType == plan.Type));
}
[Theory] [Theory]
[OrganizationInviteCustomize(InviteeUserType = OrganizationUserType.User, [OrganizationInviteCustomize(InviteeUserType = OrganizationUserType.User,
InvitorUserType = OrganizationUserType.Owner), OrganizationCustomize, BitAutoData] InvitorUserType = OrganizationUserType.Owner), OrganizationCustomize, BitAutoData]

File diff suppressed because it is too large Load Diff

View File

@ -3,14 +3,11 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Caches; using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants; using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Test.Billing.Tax.Services;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using Braintree; using Braintree;
@ -195,7 +192,7 @@ public class SubscriberServiceTests
await stripeAdapter await stripeAdapter
.DidNotReceiveWithAnyArgs() .DidNotReceiveWithAnyArgs()
.SubscriptionCancelAsync(Arg.Any<string>(), Arg.Any<SubscriptionCancelOptions>()); ; .SubscriptionCancelAsync(Arg.Any<string>(), Arg.Any<SubscriptionCancelOptions>());
} }
#endregion #endregion
@ -1029,7 +1026,7 @@ public class SubscriberServiceTests
stripeAdapter stripeAdapter
.PaymentMethodListAutoPagingAsync(Arg.Any<PaymentMethodListOptions>()) .PaymentMethodListAutoPagingAsync(Arg.Any<PaymentMethodListOptions>())
.Returns(GetPaymentMethodsAsync(new List<Stripe.PaymentMethod>())); .Returns(GetPaymentMethodsAsync(new List<PaymentMethod>()));
await sutProvider.Sut.RemovePaymentSource(organization); await sutProvider.Sut.RemovePaymentSource(organization);
@ -1061,7 +1058,7 @@ public class SubscriberServiceTests
stripeAdapter stripeAdapter
.PaymentMethodListAutoPagingAsync(Arg.Any<PaymentMethodListOptions>()) .PaymentMethodListAutoPagingAsync(Arg.Any<PaymentMethodListOptions>())
.Returns(GetPaymentMethodsAsync(new List<Stripe.PaymentMethod> .Returns(GetPaymentMethodsAsync(new List<PaymentMethod>
{ {
new () new ()
{ {
@ -1086,8 +1083,8 @@ public class SubscriberServiceTests
.PaymentMethodDetachAsync(cardId); .PaymentMethodDetachAsync(cardId);
} }
private static async IAsyncEnumerable<Stripe.PaymentMethod> GetPaymentMethodsAsync( private static async IAsyncEnumerable<PaymentMethod> GetPaymentMethodsAsync(
IEnumerable<Stripe.PaymentMethod> paymentMethods) IEnumerable<PaymentMethod> paymentMethods)
{ {
foreach (var paymentMethod in paymentMethods) foreach (var paymentMethod in paymentMethods)
{ {
@ -1598,14 +1595,22 @@ public class SubscriberServiceTests
City = "Example Town", City = "Example Town",
State = "NY" State = "NY"
}, },
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] },
Subscriptions = new StripeList<Subscription>
{
Data = [
new Subscription
{
Id = provider.GatewaySubscriptionId,
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
}
]
}
}); });
var subscription = new Subscription { Items = new StripeList<SubscriptionItem>() }; var subscription = new Subscription { Items = new StripeList<SubscriptionItem>() };
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(Arg.Any<string>()) sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(Arg.Any<string>())
.Returns(subscription); .Returns(subscription);
sutProvider.GetDependency<IAutomaticTaxFactory>().CreateAsync(Arg.Any<AutomaticTaxFactoryParameters>())
.Returns(new FakeAutomaticTaxStrategy(true));
await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation);
@ -1623,6 +1628,98 @@ public class SubscriberServiceTests
await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is<TaxIdCreateOptions>( await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is<TaxIdCreateOptions>(
options => options.Type == "us_ein" && options => options.Type == "us_ein" &&
options.Value == taxInformation.TaxId)); options.Value == taxInformation.TaxId));
await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
}
[Theory, BitAutoData]
public async Task UpdateTaxInformation_NonUser_ReverseCharge_MakesCorrectInvocations(
Provider provider,
SutProvider<SubscriberService> sutProvider)
{
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } };
stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is<CustomerGetOptions>(
options => options.Expand.Contains("tax_ids"))).Returns(customer);
var taxInformation = new TaxInformation(
"CA",
"12345",
"123456789",
"us_ein",
"123 Example St.",
null,
"Example Town",
"NY");
sutProvider.GetDependency<IStripeAdapter>()
.CustomerUpdateAsync(
Arg.Is<string>(p => p == provider.GatewayCustomerId),
Arg.Is<CustomerUpdateOptions>(options =>
options.Address.Country == "CA" &&
options.Address.PostalCode == "12345" &&
options.Address.Line1 == "123 Example St." &&
options.Address.Line2 == null &&
options.Address.City == "Example Town" &&
options.Address.State == "NY"))
.Returns(new Customer
{
Id = provider.GatewayCustomerId,
Address = new Address
{
Country = "CA",
PostalCode = "12345",
Line1 = "123 Example St.",
Line2 = null,
City = "Example Town",
State = "NY"
},
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] },
Subscriptions = new StripeList<Subscription>
{
Data = [
new Subscription
{
Id = provider.GatewaySubscriptionId,
CustomerId = provider.GatewayCustomerId,
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
}
]
}
});
var subscription = new Subscription { Items = new StripeList<SubscriptionItem>() };
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(Arg.Any<string>())
.Returns(subscription);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation);
await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(
options =>
options.Address.Country == taxInformation.Country &&
options.Address.PostalCode == taxInformation.PostalCode &&
options.Address.Line1 == taxInformation.Line1 &&
options.Address.Line2 == taxInformation.Line2 &&
options.Address.City == taxInformation.City &&
options.Address.State == taxInformation.State));
await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1");
await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is<TaxIdCreateOptions>(
options => options.Type == "us_ein" &&
options.Value == taxInformation.TaxId));
await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId,
Arg.Is<CustomerUpdateOptions>(options => options.TaxExempt == StripeConstants.TaxExempt.Reverse));
await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
} }
#endregion #endregion

View File

@ -28,7 +28,10 @@ public static class OrganizationLicenseFileFixtures
private const string Version15 = private const string Version15 =
"{\n 'LicenseKey': 'myLicenseKey',\n 'InstallationId': '78900000-0000-0000-0000-000000000123',\n 'Id': '12300000-0000-0000-0000-000000000456',\n 'Name': 'myOrg',\n 'BillingEmail': 'myBillingEmail',\n 'BusinessName': 'myBusinessName',\n 'Enabled': true,\n 'Plan': 'myPlan',\n 'PlanType': 11,\n 'Seats': 10,\n 'MaxCollections': 2,\n 'UsePolicies': true,\n 'UseSso': true,\n 'UseKeyConnector': true,\n 'UseScim': true,\n 'UseGroups': true,\n 'UseEvents': true,\n 'UseDirectory': true,\n 'UseTotp': true,\n 'Use2fa': true,\n 'UseApi': true,\n 'UseResetPassword': true,\n 'MaxStorageGb': 100,\n 'SelfHost': true,\n 'UsersGetPremium': true,\n 'UseCustomPermissions': true,\n 'Version': 14,\n 'Issued': '2023-12-14T02:03:33.374297Z',\n 'Refresh': '2023-12-07T22:42:33.970597Z',\n 'Expires': '2023-12-21T02:03:33.374297Z',\n 'ExpirationWithoutGracePeriod': null,\n 'UsePasswordManager': true,\n 'UseSecretsManager': true,\n 'SmSeats': 5,\n 'SmServiceAccounts': 8,\n 'LimitCollectionCreationDeletion': true,\n 'AllowAdminAccessToAllCollectionItems': true,\n 'Trial': true,\n 'LicenseType': 1,\n 'Hash': 'EZl4IvJaa1E5mPmlfp4p5twAtlmaxlF1yoZzVYP4vog=',\n 'Signature': ''\n}"; "{\n 'LicenseKey': 'myLicenseKey',\n 'InstallationId': '78900000-0000-0000-0000-000000000123',\n 'Id': '12300000-0000-0000-0000-000000000456',\n 'Name': 'myOrg',\n 'BillingEmail': 'myBillingEmail',\n 'BusinessName': 'myBusinessName',\n 'Enabled': true,\n 'Plan': 'myPlan',\n 'PlanType': 11,\n 'Seats': 10,\n 'MaxCollections': 2,\n 'UsePolicies': true,\n 'UseSso': true,\n 'UseKeyConnector': true,\n 'UseScim': true,\n 'UseGroups': true,\n 'UseEvents': true,\n 'UseDirectory': true,\n 'UseTotp': true,\n 'Use2fa': true,\n 'UseApi': true,\n 'UseResetPassword': true,\n 'MaxStorageGb': 100,\n 'SelfHost': true,\n 'UsersGetPremium': true,\n 'UseCustomPermissions': true,\n 'Version': 14,\n 'Issued': '2023-12-14T02:03:33.374297Z',\n 'Refresh': '2023-12-07T22:42:33.970597Z',\n 'Expires': '2023-12-21T02:03:33.374297Z',\n 'ExpirationWithoutGracePeriod': null,\n 'UsePasswordManager': true,\n 'UseSecretsManager': true,\n 'SmSeats': 5,\n 'SmServiceAccounts': 8,\n 'LimitCollectionCreationDeletion': true,\n 'AllowAdminAccessToAllCollectionItems': true,\n 'Trial': true,\n 'LicenseType': 1,\n 'Hash': 'EZl4IvJaa1E5mPmlfp4p5twAtlmaxlF1yoZzVYP4vog=',\n 'Signature': ''\n}";
private static readonly Dictionary<int, string> LicenseVersions = new() { { 12, Version12 }, { 13, Version13 }, { 14, Version14 }, { 15, Version15 } }; private const string Version16 =
"{\n'LicenseKey': 'myLicenseKey',\n'InstallationId': '78900000-0000-0000-0000-000000000123',\n'Id': '12300000-0000-0000-0000-000000000456',\n'Name': 'myOrg',\n'BillingEmail': 'myBillingEmail',\n'BusinessName': 'myBusinessName',\n'Enabled': true,\n'Plan': 'myPlan',\n'PlanType': 11,\n'Seats': 10,\n'MaxCollections': 2,\n'UsePolicies': true,\n'UseSso': true,\n'UseKeyConnector': true,\n'UseScim': true,\n'UseGroups': true,\n'UseEvents': true,\n'UseDirectory': true,\n'UseTotp': true,\n'Use2fa': true,\n'UseApi': true,\n'UseResetPassword': true,\n'MaxStorageGb': 100,\n'SelfHost': true,\n'UsersGetPremium': true,\n'UseCustomPermissions': true,\n'Version': 15,\n'Issued': '2025-05-16T20:50:09.036931Z',\n'Refresh': '2025-05-23T20:50:09.036931Z',\n'Expires': '2025-05-23T20:50:09.036931Z',\n'ExpirationWithoutGracePeriod': null,\n'UsePasswordManager': true,\n'UseSecretsManager': true,\n'SmSeats': 5,\n'SmServiceAccounts': 8,\n'UseRiskInsights': false,\n'LimitCollectionCreationDeletion': true,\n'AllowAdminAccessToAllCollectionItems': true,\n'Trial': true,\n'LicenseType': 1,\n'UseOrganizationDomains': true,\n'UseAdminSponsoredFamilies': false,\n'Hash': 'k3M9SpHKUo0TmuSnNipeZleCHxgcEycKRXYl9BAg30Q=',\n'Signature': '',\n'Token': null\n}";
private static readonly Dictionary<int, string> LicenseVersions = new() { { 12, Version12 }, { 13, Version13 }, { 14, Version14 }, { 15, Version15 }, { 16, Version16 } };
public static OrganizationLicense GetVersion(int licenseVersion) public static OrganizationLicense GetVersion(int licenseVersion)
{ {

View File

@ -347,7 +347,7 @@ public class UserServiceTests
SutProvider<UserService> sutProvider, Guid userId, Organization organization) SutProvider<UserService> sutProvider, Guid userId, Organization organization)
{ {
organization.Enabled = true; organization.Enabled = true;
organization.UseSso = true; organization.UseOrganizationDomains = true;
sutProvider.GetDependency<IOrganizationRepository>() sutProvider.GetDependency<IOrganizationRepository>()
.GetByVerifiedUserEmailDomainAsync(userId) .GetByVerifiedUserEmailDomainAsync(userId)
@ -362,7 +362,7 @@ public class UserServiceTests
SutProvider<UserService> sutProvider, Guid userId, Organization organization) SutProvider<UserService> sutProvider, Guid userId, Organization organization)
{ {
organization.Enabled = false; organization.Enabled = false;
organization.UseSso = true; organization.UseOrganizationDomains = true;
sutProvider.GetDependency<IOrganizationRepository>() sutProvider.GetDependency<IOrganizationRepository>()
.GetByVerifiedUserEmailDomainAsync(userId) .GetByVerifiedUserEmailDomainAsync(userId)
@ -373,11 +373,11 @@ public class UserServiceTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task IsClaimedByAnyOrganizationAsync_WithOrganizationUseSsoFalse_ReturnsFalse( public async Task IsClaimedByAnyOrganizationAsync_WithOrganizationUseOrganizationDomaisFalse_ReturnsFalse(
SutProvider<UserService> sutProvider, Guid userId, Organization organization) SutProvider<UserService> sutProvider, Guid userId, Organization organization)
{ {
organization.Enabled = true; organization.Enabled = true;
organization.UseSso = false; organization.UseOrganizationDomains = false;
sutProvider.GetDependency<IOrganizationRepository>() sutProvider.GetDependency<IOrganizationRepository>()
.GetByVerifiedUserEmailDomainAsync(userId) .GetByVerifiedUserEmailDomainAsync(userId)

View File

@ -0,0 +1,118 @@
using System.Text.Json;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Commands;
using Bit.Core.Tools.Services;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Tools.Services;
public class AnonymousSendCommandTests
{
private readonly ISendRepository _sendRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPushNotificationService _pushNotificationService;
private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly AnonymousSendCommand _anonymousSendCommand;
public AnonymousSendCommandTests()
{
_sendRepository = Substitute.For<ISendRepository>();
_sendFileStorageService = Substitute.For<ISendFileStorageService>();
_pushNotificationService = Substitute.For<IPushNotificationService>();
_sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
_anonymousSendCommand = new AnonymousSendCommand(
_sendRepository,
_sendFileStorageService,
_pushNotificationService,
_sendAuthorizationService);
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_Success_ReturnsDownloadUrl()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.File,
AccessCount = 0,
Data = JsonSerializer.Serialize(new { Id = "fileId123" })
};
var fileId = "fileId123";
var password = "testPassword";
var expectedUrl = "https://example.com/download";
_sendAuthorizationService
.SendCanBeAccessed(send, password)
.Returns(SendAccessResult.Granted);
_sendFileStorageService
.GetSendFileDownloadUrlAsync(send, fileId)
.Returns(expectedUrl);
// Act
var result =
await _anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId, password);
// Assert
Assert.Equal(expectedUrl, result.Item1);
Assert.Equal(1, send.AccessCount);
await _sendRepository.Received(1).ReplaceAsync(send);
await _pushNotificationService.Received(1).PushSyncSendUpdateAsync(send);
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_AccessDenied_ReturnsNullWithReasons()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.File,
AccessCount = 0
};
var fileId = "fileId123";
var password = "wrongPassword";
_sendAuthorizationService
.SendCanBeAccessed(send, password)
.Returns(SendAccessResult.Denied);
// Act
var result =
await _anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId, password);
// Assert
Assert.Null(result.Item1);
Assert.Equal(SendAccessResult.Denied, result.Item2);
Assert.Equal(0, send.AccessCount);
await _sendRepository.DidNotReceiveWithAnyArgs().ReplaceAsync(default);
await _pushNotificationService.DidNotReceiveWithAnyArgs().PushSyncSendUpdateAsync(default);
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_NotFileSend_ThrowsBadRequestException()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.Text
};
var fileId = "fileId123";
var password = "testPassword";
// Act & Assert
await Assert.ThrowsAsync<BadRequestException>(() =>
_anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId, password));
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,175 @@
using Bit.Core.Context;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.Services;
using Microsoft.AspNetCore.Identity;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Tools.Services;
public class SendAuthorizationServiceTests
{
private readonly ISendRepository _sendRepository;
private readonly IPasswordHasher<Bit.Core.Entities.User> _passwordHasher;
private readonly IPushNotificationService _pushNotificationService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
private readonly SendAuthorizationService _sendAuthorizationService;
public SendAuthorizationServiceTests()
{
_sendRepository = Substitute.For<ISendRepository>();
_passwordHasher = Substitute.For<IPasswordHasher<Bit.Core.Entities.User>>();
_pushNotificationService = Substitute.For<IPushNotificationService>();
_referenceEventService = Substitute.For<IReferenceEventService>();
_currentContext = Substitute.For<ICurrentContext>();
_sendAuthorizationService = new SendAuthorizationService(
_sendRepository,
_passwordHasher,
_pushNotificationService,
_referenceEventService,
_currentContext);
}
[Fact]
public void SendCanBeAccessed_Success_ReturnsTrue()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
MaxAccessCount = 10,
AccessCount = 5,
ExpirationDate = DateTime.UtcNow.AddYears(1),
DeletionDate = DateTime.UtcNow.AddYears(1),
Disabled = false,
Password = "hashedPassword123"
};
const string password = "TEST";
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), send.Password, password)
.Returns(PasswordVerificationResult.Success);
// Act
var result =
_sendAuthorizationService.SendCanBeAccessed(send, password);
// Assert
Assert.Equal(SendAccessResult.Granted, result);
}
[Fact]
public void SendCanBeAccessed_NullMaxAccess_Success()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
MaxAccessCount = null,
AccessCount = 5,
ExpirationDate = DateTime.UtcNow.AddYears(1),
DeletionDate = DateTime.UtcNow.AddYears(1),
Disabled = false,
Password = "hashedPassword123"
};
const string password = "TEST";
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), send.Password, password)
.Returns(PasswordVerificationResult.Success);
// Act
var result = _sendAuthorizationService.SendCanBeAccessed(send, password);
// Assert
Assert.Equal(SendAccessResult.Granted, result);
}
[Fact]
public void SendCanBeAccessed_NullSend_DoesNotGrantAccess()
{
// Arrange
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Success);
// Act
var result =
_sendAuthorizationService.SendCanBeAccessed(null, "TEST");
// Assert
Assert.Equal(SendAccessResult.Denied, result);
}
[Fact]
public void SendCanBeAccessed_RehashNeeded_RehashesPassword()
{
// Arrange
var now = DateTime.UtcNow;
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
MaxAccessCount = null,
AccessCount = 5,
ExpirationDate = now.AddYears(1),
DeletionDate = now.AddYears(1),
Disabled = false,
Password = "TEST"
};
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.SuccessRehashNeeded);
// Act
var result =
_sendAuthorizationService.SendCanBeAccessed(send, "TEST");
// Assert
_passwordHasher
.Received(1)
.HashPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST");
Assert.Equal(SendAccessResult.Granted, result);
}
[Fact]
public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue()
{
// Arrange
var now = DateTime.UtcNow;
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
MaxAccessCount = null,
AccessCount = 5,
ExpirationDate = now.AddYears(1),
DeletionDate = now.AddYears(1),
Disabled = false,
Password = "TEST"
};
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Failed);
// Act
var result =
_sendAuthorizationService.SendCanBeAccessed(send, "TEST");
// Assert
Assert.Equal(SendAccessResult.PasswordInvalid, result);
}
}

View File

@ -1,867 +0,0 @@
using System.Text;
using System.Text.Json;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Test.AutoFixture.CurrentContextFixtures;
using Bit.Core.Test.Entities;
using Bit.Core.Test.Tools.AutoFixture.SendFixtures;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Identity;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using Xunit;
using GlobalSettings = Bit.Core.Settings.GlobalSettings;
namespace Bit.Core.Test.Tools.Services;
[SutProviderCustomize]
[CurrentContextCustomize]
[UserSendCustomize]
public class SendServiceTests
{
private void SaveSendAsync_Setup(SendType sendType, bool disableSendPolicyAppliesToUser,
SutProvider<SendService> sutProvider, Send send)
{
send.Id = default;
send.Type = sendType;
sutProvider.GetDependency<IPolicyService>().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.DisableSend).Returns(disableSendPolicyAppliesToUser);
}
// Disable Send policy check
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableSend_Applies_throws(SendType sendType,
SutProvider<SendService> sutProvider, Send send)
{
SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: true, sutProvider, send);
await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableSend_DoesntApply_success(SendType sendType,
SutProvider<SendService> sutProvider, Send send)
{
SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: false, sutProvider, send);
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
// Send Options Policy - Disable Hide Email check
private void SaveSendAsync_HideEmail_Setup(bool disableHideEmailAppliesToUser,
SutProvider<SendService> sutProvider, Send send, Policy policy)
{
send.HideEmail = true;
var sendOptions = new SendOptionsPolicyData
{
DisableHideEmail = disableHideEmailAppliesToUser
};
policy.Data = JsonSerializer.Serialize(sendOptions, new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
});
sutProvider.GetDependency<IPolicyService>().GetPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.SendOptions).Returns(new List<OrganizationUserPolicyDetails>()
{
new() { PolicyType = policy.Type, PolicyData = policy.Data, OrganizationId = policy.OrganizationId, PolicyEnabled = policy.Enabled }
});
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_Applies_throws(SendType sendType,
SutProvider<SendService> sutProvider, Send send, Policy policy)
{
SaveSendAsync_Setup(sendType, false, sutProvider, send);
SaveSendAsync_HideEmail_Setup(true, sutProvider, send, policy);
await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_DoesntApply_success(SendType sendType,
SutProvider<SendService> sutProvider, Send send, Policy policy)
{
SaveSendAsync_Setup(sendType, false, sutProvider, send);
SaveSendAsync_HideEmail_Setup(false, sutProvider, send, policy);
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
// Disable Send policy check - vNext
private void SaveSendAsync_Setup_vNext(SutProvider<SendService> sutProvider, Send send,
DisableSendPolicyRequirement disableSendPolicyRequirement, SendOptionsPolicyRequirement sendOptionsPolicyRequirement)
{
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<DisableSendPolicyRequirement>(send.UserId!.Value)
.Returns(disableSendPolicyRequirement);
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<SendOptionsPolicyRequirement>(send.UserId!.Value)
.Returns(sendOptionsPolicyRequirement);
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
// Should not be called in these tests
sutProvider.GetDependency<IPolicyService>().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>()).ThrowsAsync<Exception>();
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableSend_Applies_Throws_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement { DisableSend = true }, new SendOptionsPolicyRequirement());
var exception = await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
Assert.Contains("Due to an Enterprise Policy, you are only able to delete an existing Send.",
exception.Message);
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableSend_DoesntApply_Success_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement());
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
// Send Options Policy - Disable Hide Email check
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_Applies_Throws_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement { DisableHideEmail = true });
send.HideEmail = true;
var exception = await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
Assert.Contains("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.", exception.Message);
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_Applies_ButEmailNotHidden_Success_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement { DisableHideEmail = true });
send.HideEmail = false;
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_DoesntApply_Success_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement());
send.HideEmail = true;
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
[Theory]
[BitAutoData]
public async Task SaveSendAsync_ExistingSend_Updates(SutProvider<SendService> sutProvider,
Send send)
{
send.Id = Guid.NewGuid();
var now = DateTime.UtcNow;
await sutProvider.Sut.SaveSendAsync(send);
Assert.True(send.RevisionDate - now < TimeSpan.FromSeconds(1));
await sutProvider.GetDependency<ISendRepository>()
.Received(1)
.UpsertAsync(send);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushSyncSendUpdateAsync(send);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_TextType_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
send.Type = SendType.Text;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 0)
);
Assert.Contains("not of type \"file\"", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_EmptyFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
send.Type = SendType.File;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 0)
);
Assert.Contains("no file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCannotAccessPremium_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(false);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("must have premium", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserHasUnconfirmedEmail_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = false,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("must confirm your email", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCanAccessPremium_HasNoStorage_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
Premium = true,
MaxStorageGb = null,
Storage = 0,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCanAccessPremium_StorageFull_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
Premium = true,
MaxStorageGb = 2,
Storage = 2 * UserTests.Multiplier,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsSelfHosted_GiantFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
Premium = false,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<GlobalSettings>()
.SelfHosted = true;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 11000 * UserTests.Multiplier)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsNotSelfHosted_TwoGigabyteFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
Premium = false,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<GlobalSettings>()
.SelfHosted = false;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var org = new Organization
{
Id = Guid.NewGuid(),
MaxStorageGb = null,
};
send.UserId = null;
send.OrganizationId = org.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(org.Id)
.Returns(org);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_TwoGBFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var org = new Organization
{
Id = Guid.NewGuid(),
MaxStorageGb = null,
};
send.UserId = null;
send.OrganizationId = org.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(org.Id)
.Returns(org);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_ThroughOrg_MaxStorageIsOneGB_TwoGBFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var org = new Organization
{
Id = Guid.NewGuid(),
MaxStorageGb = 1,
};
send.UserId = null;
send.OrganizationId = org.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(org.Id)
.Returns(org);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_HasEnoughStorage_Success(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
MaxStorageGb = 10,
};
var data = new SendFileData
{
};
send.UserId = user.Id;
send.Type = SendType.File;
var testUrl = "https://test.com/";
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<ISendFileStorageService>()
.GetSendFileUploadUrlAsync(send, Arg.Any<string>())
.Returns(testUrl);
var utcNow = DateTime.UtcNow;
var url = await sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier);
Assert.Equal(testUrl, url);
Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1));
await sutProvider.GetDependency<ISendFileStorageService>()
.Received(1)
.GetSendFileUploadUrlAsync(send, Arg.Any<string>());
await sutProvider.GetDependency<ISendRepository>()
.Received(1)
.UpsertAsync(send);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushSyncSendUpdateAsync(send);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_HasEnoughStorage_SendFileThrows_CleansUp(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
MaxStorageGb = 10,
};
var data = new SendFileData
{
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<ISendFileStorageService>()
.GetSendFileUploadUrlAsync(send, Arg.Any<string>())
.Returns<string>(callInfo => throw new Exception("Problem"));
var utcNow = DateTime.UtcNow;
var exception = await Assert.ThrowsAsync<Exception>(() =>
sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier)
);
Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1));
Assert.Equal("Problem", exception.Message);
await sutProvider.GetDependency<ISendFileStorageService>()
.Received(1)
.GetSendFileUploadUrlAsync(send, Arg.Any<string>());
await sutProvider.GetDependency<ISendRepository>()
.Received(1)
.UpsertAsync(send);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushSyncSendUpdateAsync(send);
await sutProvider.GetDependency<ISendFileStorageService>()
.Received(1)
.DeleteFileAsync(send, Arg.Any<string>());
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_SendNull_ThrowsBadRequest(SutProvider<SendService> sutProvider)
{
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), null)
);
Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_SendDataNull_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
send.Data = null;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send)
);
Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_NotFileType_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send)
);
Assert.Contains("not a file type send", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_Success(SutProvider<SendService> sutProvider,
Send send)
{
var fileContents = "Test file content";
var sendFileData = new SendFileData
{
Id = "TEST",
Size = fileContents.Length,
Validated = false,
};
send.Type = SendType.File;
send.Data = JsonSerializer.Serialize(sendFileData);
sutProvider.GetDependency<ISendFileStorageService>()
.ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any<long>())
.Returns((true, sendFileData.Size));
await sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send);
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_InvalidSize(SutProvider<SendService> sutProvider,
Send send)
{
var fileContents = "Test file content";
var sendFileData = new SendFileData
{
Id = "TEST",
Size = fileContents.Length,
};
send.Type = SendType.File;
send.Data = JsonSerializer.Serialize(sendFileData);
sutProvider.GetDependency<ISendFileStorageService>()
.ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any<long>())
.Returns((false, sendFileData.Size));
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send)
);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_Success(SutProvider<SendService> sutProvider, Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = 10;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), send.Password, "TEST")
.Returns(PasswordVerificationResult.Success);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
Assert.True(grant);
Assert.False(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_NullMaxAccess_Success(SutProvider<SendService> sutProvider,
Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = null;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), send.Password, "TEST")
.Returns(PasswordVerificationResult.Success);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
Assert.True(grant);
Assert.False(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_NullSend_DoesNotGrantAccess(SutProvider<SendService> sutProvider)
{
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Success);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(null, "TEST");
Assert.False(grant);
Assert.False(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_NullPassword_PasswordRequiredErrorReturnsTrue(SutProvider<SendService> sutProvider,
Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = null;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
send.Password = "HASH";
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Success);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, null);
Assert.False(grant);
Assert.True(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_RehashNeeded_RehashesPassword(SutProvider<SendService> sutProvider,
Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = null;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
send.Password = "TEST";
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.SuccessRehashNeeded);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
sutProvider.GetDependency<IPasswordHasher<User>>()
.Received(1)
.HashPassword(Arg.Any<User>(), "TEST");
Assert.True(grant);
Assert.False(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue(SutProvider<SendService> sutProvider,
Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = null;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
send.Password = "TEST";
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Failed);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
Assert.False(grant);
Assert.False(passwordRequiredError);
Assert.True(passwordInvalidError);
}
}

View File

@ -57,8 +57,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
var localFactory = new IdentityApplicationFactory(); var localFactory = new IdentityApplicationFactory();
var user = await localFactory.RegisterNewIdentityFactoryUserAsync(requestModel); var user = await localFactory.RegisterNewIdentityFactoryUserAsync(requestModel);
var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash, var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash);
context => context.SetAuthEmail(user.Email));
using var body = await AssertDefaultTokenBodyAsync(context); using var body = await AssertDefaultTokenBodyAsync(context);
var root = body.RootElement; var root = body.RootElement;
@ -72,71 +71,6 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
AssertUserDecryptionOptions(root); AssertUserDecryptionOptions(root);
} }
[Theory, BitAutoData, RegisterFinishRequestModelCustomize]
public async Task TokenEndpoint_GrantTypePassword_NoAuthEmailHeader_Fails(
RegisterFinishRequestModel requestModel)
{
requestModel.Email = "test+noauthemailheader@email.com";
var localFactory = new IdentityApplicationFactory();
var user = await localFactory.RegisterNewIdentityFactoryUserAsync(requestModel);
var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash, null);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
var root = body.RootElement;
var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString();
Assert.Equal("invalid_grant", error);
AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String);
}
[Theory, BitAutoData, RegisterFinishRequestModelCustomize]
public async Task TokenEndpoint_GrantTypePassword_InvalidBase64AuthEmailHeader_Fails(
RegisterFinishRequestModel requestModel)
{
requestModel.Email = "test+badauthheader@email.com";
var localFactory = new IdentityApplicationFactory();
var user = await localFactory.RegisterNewIdentityFactoryUserAsync(requestModel);
var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash,
context => context.Request.Headers.Append("Auth-Email", "bad_value"));
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
var root = body.RootElement;
var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString();
Assert.Equal("invalid_grant", error);
AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String);
}
[Theory, BitAutoData, RegisterFinishRequestModelCustomize]
public async Task TokenEndpoint_GrantTypePassword_WrongAuthEmailHeader_Fails(
RegisterFinishRequestModel requestModel)
{
requestModel.Email = "test+badauthheader@email.com";
var localFactory = new IdentityApplicationFactory();
var user = await localFactory.RegisterNewIdentityFactoryUserAsync(requestModel);
var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash,
context => context.SetAuthEmail("bad_value"));
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
var root = body.RootElement;
var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString();
Assert.Equal("invalid_grant", error);
AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String);
}
[Theory, RegisterFinishRequestModelCustomize] [Theory, RegisterFinishRequestModelCustomize]
[BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Owner)]
[BitAutoData(OrganizationUserType.Admin)] [BitAutoData(OrganizationUserType.Admin)]
@ -157,8 +91,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
await CreateOrganizationWithSsoPolicyAsync(localFactory, await CreateOrganizationWithSsoPolicyAsync(localFactory,
organizationId, user.Email, organizationUserType, ssoPolicyEnabled: false); organizationId, user.Email, organizationUserType, ssoPolicyEnabled: false);
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash, var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
context => context.SetAuthEmail(user.Email));
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
} }
@ -184,8 +117,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
await CreateOrganizationWithSsoPolicyAsync( await CreateOrganizationWithSsoPolicyAsync(
localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: false); localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: false);
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash, var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
context => context.SetAuthEmail(user.Email));
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
} }
@ -209,8 +141,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
await CreateOrganizationWithSsoPolicyAsync(localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: true); await CreateOrganizationWithSsoPolicyAsync(localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: true);
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash, var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
context => context.SetAuthEmail(user.Email));
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await AssertRequiredSsoAuthenticationResponseAsync(context); await AssertRequiredSsoAuthenticationResponseAsync(context);
@ -234,8 +165,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
await CreateOrganizationWithSsoPolicyAsync(localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: true); await CreateOrganizationWithSsoPolicyAsync(localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: true);
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash, var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
context => context.SetAuthEmail(user.Email));
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
} }
@ -258,8 +188,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
await CreateOrganizationWithSsoPolicyAsync(localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: true); await CreateOrganizationWithSsoPolicyAsync(localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: true);
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash, var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
context => context.SetAuthEmail(user.Email));
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await AssertRequiredSsoAuthenticationResponseAsync(context); await AssertRequiredSsoAuthenticationResponseAsync(context);
@ -342,7 +271,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
{ "grant_type", "password" }, { "grant_type", "password" },
{ "username", model.Email }, { "username", model.Email },
{ "password", model.MasterPasswordHash }, { "password", model.MasterPasswordHash },
}), context => context.SetAuthEmail(model.Email)); }));
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
@ -554,12 +483,12 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
{ "grant_type", "password" }, { "grant_type", "password" },
{ "username", user.Email}, { "username", user.Email},
{ "password", "master_password_hash" }, { "password", "master_password_hash" },
}), context => context.SetAuthEmail(user.Email).SetIp("1.1.1.2")); }), context => context.SetIp("1.1.1.2"));
} }
} }
private async Task<HttpContext> PostLoginAsync( private async Task<HttpContext> PostLoginAsync(
TestServer server, User user, string MasterPasswordHash, Action<HttpContext> extraConfiguration) TestServer server, User user, string MasterPasswordHash)
{ {
return await server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary<string, string> return await server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary<string, string>
{ {
@ -571,7 +500,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
{ "grant_type", "password" }, { "grant_type", "password" },
{ "username", user.Email }, { "username", user.Email },
{ "password", MasterPasswordHash }, { "password", MasterPasswordHash },
}), extraConfiguration); }));
} }
private async Task CreateOrganizationWithSsoPolicyAsync( private async Task CreateOrganizationWithSsoPolicyAsync(

View File

@ -143,7 +143,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
{ "grant_type", "password" }, { "grant_type", "password" },
{ "username", _testEmail }, { "username", _testEmail },
{ "password", _testPassword }, { "password", _testPassword },
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail))); }));
// Assert // Assert
using var responseBody = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context); using var responseBody = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
@ -263,7 +263,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
{ "code", "test_code" }, { "code", "test_code" },
{ "code_verifier", challenge }, { "code_verifier", challenge },
{ "redirect_uri", "https://localhost:8080/sso-connector.html" } { "redirect_uri", "https://localhost:8080/sso-connector.html" }
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail))); }));
// Assert // Assert
using var responseBody = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context); using var responseBody = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
@ -307,7 +307,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
{ "code", "test_code" }, { "code", "test_code" },
{ "code_verifier", challenge }, { "code_verifier", challenge },
{ "redirect_uri", "https://localhost:8080/sso-connector.html" } { "redirect_uri", "https://localhost:8080/sso-connector.html" }
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail))); }));
Assert.Equal(StatusCodes.Status400BadRequest, failedTokenContext.Response.StatusCode); Assert.Equal(StatusCodes.Status400BadRequest, failedTokenContext.Response.StatusCode);
Assert.NotNull(emailToken); Assert.NotNull(emailToken);
@ -326,7 +326,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
{ "code", "test_code" }, { "code", "test_code" },
{ "code_verifier", challenge }, { "code_verifier", challenge },
{ "redirect_uri", "https://localhost:8080/sso-connector.html" } { "redirect_uri", "https://localhost:8080/sso-connector.html" }
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail))); }));
// Assert // Assert
@ -363,7 +363,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
{ "code", "test_code" }, { "code", "test_code" },
{ "code_verifier", challenge }, { "code_verifier", challenge },
{ "redirect_uri", "https://localhost:8080/sso-connector.html" } { "redirect_uri", "https://localhost:8080/sso-connector.html" }
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail))); }));
// Assert // Assert
using var responseBody = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context); using var responseBody = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);

View File

@ -29,8 +29,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
// Act // Act
var context = await localFactory.Server.PostAsync("/connect/token", var context = await localFactory.Server.PostAsync("/connect/token",
GetFormUrlEncodedContent(), GetFormUrlEncodedContent());
context => context.SetAuthEmail(DefaultUsername));
// Assert // Assert
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context); var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
@ -40,27 +39,6 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
Assert.NotNull(token); Assert.NotNull(token);
} }
[Fact]
public async Task ValidateAsync_AuthEmailHeaderInvalid_InvalidGrantResponse()
{
// Arrange
var localFactory = new IdentityApplicationFactory();
await EnsureUserCreatedAsync(localFactory);
// Act
var context = await localFactory.Server.PostAsync(
"/connect/token",
GetFormUrlEncodedContent()
);
// Assert
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
var root = body.RootElement;
var error = AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String).GetString();
Assert.Equal("Auth-Email header invalid.", error);
}
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task ValidateAsync_UserNull_Failure(string username) public async Task ValidateAsync_UserNull_Failure(string username)
{ {
@ -68,8 +46,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
var localFactory = new IdentityApplicationFactory(); var localFactory = new IdentityApplicationFactory();
// Act // Act
var context = await localFactory.Server.PostAsync("/connect/token", var context = await localFactory.Server.PostAsync("/connect/token",
GetFormUrlEncodedContent(username: username), GetFormUrlEncodedContent(username: username));
context => context.SetAuthEmail(username));
// Assert // Assert
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context); var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
@ -106,8 +83,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
// Act // Act
var context = await localFactory.Server.PostAsync("/connect/token", var context = await localFactory.Server.PostAsync("/connect/token",
GetFormUrlEncodedContent(password: badPassword), GetFormUrlEncodedContent(password: badPassword));
context => context.SetAuthEmail(DefaultUsername));
// Assert // Assert
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context); var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
@ -155,7 +131,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
{ "username", DefaultUsername }, { "username", DefaultUsername },
{ "password", DefaultPassword }, { "password", DefaultPassword },
{ "AuthRequest", authRequest.Id.ToString().ToLowerInvariant() } { "AuthRequest", authRequest.Id.ToString().ToLowerInvariant() }
}), context => context.SetAuthEmail(DefaultUsername)); }));
// Assert // Assert
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context); var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
@ -197,7 +173,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
{ "username", DefaultUsername }, { "username", DefaultUsername },
{ "password", DefaultPassword }, { "password", DefaultPassword },
{ "AuthRequest", authRequest.Id.ToString().ToLowerInvariant() } { "AuthRequest", authRequest.Id.ToString().ToLowerInvariant() }
}), context => context.SetAuthEmail(DefaultUsername)); }));
// Assert // Assert

View File

@ -6,6 +6,7 @@ using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Identity.IdentityServer; using Bit.Identity.IdentityServer;
using Bit.Identity.Utilities;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute; using NSubstitute;
using Xunit; using Xunit;
@ -17,6 +18,7 @@ public class UserDecryptionOptionsBuilderTests
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly IDeviceRepository _deviceRepository; private readonly IDeviceRepository _deviceRepository;
private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly ILoginApprovingClientTypes _loginApprovingClientTypes;
private readonly UserDecryptionOptionsBuilder _builder; private readonly UserDecryptionOptionsBuilder _builder;
public UserDecryptionOptionsBuilderTests() public UserDecryptionOptionsBuilderTests()
@ -24,7 +26,8 @@ public class UserDecryptionOptionsBuilderTests
_currentContext = Substitute.For<ICurrentContext>(); _currentContext = Substitute.For<ICurrentContext>();
_deviceRepository = Substitute.For<IDeviceRepository>(); _deviceRepository = Substitute.For<IDeviceRepository>();
_organizationUserRepository = Substitute.For<IOrganizationUserRepository>(); _organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
_builder = new UserDecryptionOptionsBuilder(_currentContext, _deviceRepository, _organizationUserRepository); _loginApprovingClientTypes = Substitute.For<ILoginApprovingClientTypes>();
_builder = new UserDecryptionOptionsBuilder(_currentContext, _deviceRepository, _organizationUserRepository, _loginApprovingClientTypes);
} }
[Theory] [Theory]
@ -120,17 +123,17 @@ public class UserDecryptionOptionsBuilderTests
[BitAutoData(DeviceType.SafariBrowser)] [BitAutoData(DeviceType.SafariBrowser)]
[BitAutoData(DeviceType.VivaldiBrowser)] [BitAutoData(DeviceType.VivaldiBrowser)]
[BitAutoData(DeviceType.UnknownBrowser)] [BitAutoData(DeviceType.UnknownBrowser)]
// Extension
[BitAutoData(DeviceType.ChromeExtension)]
[BitAutoData(DeviceType.FirefoxExtension)]
[BitAutoData(DeviceType.OperaExtension)]
[BitAutoData(DeviceType.EdgeExtension)]
[BitAutoData(DeviceType.VivaldiExtension)]
[BitAutoData(DeviceType.SafariExtension)]
public async Task Build_WhenHasLoginApprovingDevice_ShouldApprovingDeviceTrue( public async Task Build_WhenHasLoginApprovingDevice_ShouldApprovingDeviceTrue(
DeviceType deviceType, DeviceType deviceType,
SsoConfig ssoConfig, SsoConfigurationData configurationData, User user, Device device, Device approvingDevice) SsoConfig ssoConfig, SsoConfigurationData configurationData, User user, Device device, Device approvingDevice)
{ {
_loginApprovingClientTypes.TypesThatCanApprove.Returns(new List<ClientType>
{
ClientType.Desktop,
ClientType.Mobile,
ClientType.Web,
});
configurationData.MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption; configurationData.MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption;
ssoConfig.Data = configurationData.Serialize(); ssoConfig.Data = configurationData.Serialize();
approvingDevice.Type = deviceType; approvingDevice.Type = deviceType;
@ -142,9 +145,65 @@ public class UserDecryptionOptionsBuilderTests
} }
[Theory] [Theory]
// Desktop
[BitAutoData(DeviceType.LinuxDesktop)]
[BitAutoData(DeviceType.MacOsDesktop)]
[BitAutoData(DeviceType.WindowsDesktop)]
[BitAutoData(DeviceType.UWP)]
// Mobile
[BitAutoData(DeviceType.Android)]
[BitAutoData(DeviceType.iOS)]
[BitAutoData(DeviceType.AndroidAmazon)]
// Web
[BitAutoData(DeviceType.ChromeBrowser)]
[BitAutoData(DeviceType.FirefoxBrowser)]
[BitAutoData(DeviceType.OperaBrowser)]
[BitAutoData(DeviceType.EdgeBrowser)]
[BitAutoData(DeviceType.IEBrowser)]
[BitAutoData(DeviceType.SafariBrowser)]
[BitAutoData(DeviceType.VivaldiBrowser)]
[BitAutoData(DeviceType.UnknownBrowser)]
// Extension
[BitAutoData(DeviceType.ChromeExtension)]
[BitAutoData(DeviceType.FirefoxExtension)]
[BitAutoData(DeviceType.OperaExtension)]
[BitAutoData(DeviceType.EdgeExtension)]
[BitAutoData(DeviceType.VivaldiExtension)]
[BitAutoData(DeviceType.SafariExtension)]
public async Task Build_WhenHasLoginApprovingDeviceFeatureFlag_ShouldApprovingDeviceTrue(
DeviceType deviceType,
SsoConfig ssoConfig, SsoConfigurationData configurationData, User user, Device device, Device approvingDevice)
{
_loginApprovingClientTypes.TypesThatCanApprove.Returns(new List<ClientType>
{
ClientType.Desktop,
ClientType.Mobile,
ClientType.Web,
ClientType.Browser,
});
configurationData.MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption;
ssoConfig.Data = configurationData.Serialize();
approvingDevice.Type = deviceType;
_deviceRepository.GetManyByUserIdAsync(user.Id).Returns(new Device[] { approvingDevice });
var result = await _builder.ForUser(user).WithSso(ssoConfig).WithDevice(device).BuildAsync();
Assert.True(result.TrustedDeviceOption?.HasLoginApprovingDevice);
}
[Theory]
// CLI
[BitAutoData(DeviceType.WindowsCLI)] [BitAutoData(DeviceType.WindowsCLI)]
[BitAutoData(DeviceType.MacOsCLI)] [BitAutoData(DeviceType.MacOsCLI)]
[BitAutoData(DeviceType.LinuxCLI)] [BitAutoData(DeviceType.LinuxCLI)]
// Extension
[BitAutoData(DeviceType.ChromeExtension)]
[BitAutoData(DeviceType.FirefoxExtension)]
[BitAutoData(DeviceType.OperaExtension)]
[BitAutoData(DeviceType.EdgeExtension)]
[BitAutoData(DeviceType.VivaldiExtension)]
[BitAutoData(DeviceType.SafariExtension)]
public async Task Build_WhenHasLoginApprovingDevice_ShouldApprovingDeviceFalse( public async Task Build_WhenHasLoginApprovingDevice_ShouldApprovingDeviceFalse(
DeviceType deviceType, DeviceType deviceType,
SsoConfig ssoConfig, SsoConfigurationData configurationData, User user, Device device, Device approvingDevice) SsoConfig ssoConfig, SsoConfigurationData configurationData, User user, Device device, Device approvingDevice)

View File

@ -5,10 +5,8 @@ using Bit.Core.Auth.Models.Api.Request.Accounts;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Utilities;
using Bit.Identity; using Bit.Identity;
using Bit.Test.Common.Helpers; using Bit.Test.Common.Helpers;
using HandlebarsDotNet;
using LinqToDB; using LinqToDB;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
@ -98,7 +96,7 @@ public class IdentityApplicationFactory : WebApplicationFactoryBase<Startup>
{ "grant_type", "password" }, { "grant_type", "password" },
{ "username", username }, { "username", username },
{ "password", password }, { "password", password },
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(username))); }));
return context; return context;
} }
@ -126,7 +124,7 @@ public class IdentityApplicationFactory : WebApplicationFactoryBase<Startup>
{ "TwoFactorToken", twoFactorToken }, { "TwoFactorToken", twoFactorToken },
{ "TwoFactorProvider", twoFactorProviderType }, { "TwoFactorProvider", twoFactorProviderType },
{ "TwoFactorRemember", "1" }, { "TwoFactorRemember", "1" },
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(username))); }));
return context; return context;
} }

View File

@ -1,5 +1,4 @@
using System.Net; using System.Net;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.Primitives; using Microsoft.Extensions.Primitives;
@ -62,12 +61,6 @@ public static class WebApplicationFactoryExtensions
Action<HttpContext> extraConfiguration = null) Action<HttpContext> extraConfiguration = null)
=> SendAsync(server, HttpMethod.Delete, requestUri, content: content, extraConfiguration); => SendAsync(server, HttpMethod.Delete, requestUri, content: content, extraConfiguration);
public static HttpContext SetAuthEmail(this HttpContext context, string username)
{
context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(username));
return context;
}
public static HttpContext SetIp(this HttpContext context, string ip) public static HttpContext SetIp(this HttpContext context, string ip)
{ {
context.Connection.RemoteIpAddress = IPAddress.Parse(ip); context.Connection.RemoteIpAddress = IPAddress.Parse(ip);