mirror of
https://github.com/bitwarden/server.git
synced 2025-05-22 12:04:27 -05:00
Merge branch 'main' into ac/pm-21612/unified-mariadb-cannot-edit-an-unconfirmed-user
This commit is contained in:
commit
4e727f3660
4
.github/workflows/build.yml
vendored
4
.github/workflows/build.yml
vendored
@ -636,7 +636,9 @@ jobs:
|
||||
|
||||
setup-ephemeral-environment:
|
||||
name: Setup Ephemeral Environment
|
||||
needs: build-docker
|
||||
needs:
|
||||
- build-artifacts
|
||||
- build-docker
|
||||
if: |
|
||||
needs.build-artifacts.outputs.has_secrets == 'true'
|
||||
&& github.event_name == 'pull_request'
|
||||
|
@ -3,7 +3,7 @@
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net8.0</TargetFramework>
|
||||
|
||||
<Version>2025.5.0</Version>
|
||||
<Version>2025.5.1</Version>
|
||||
|
||||
<RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
|
@ -8,13 +8,10 @@ using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Billing.Services;
|
||||
using Bit.Core.Billing.Tax.Services;
|
||||
using Bit.Core.Billing.Tax.Services.Implementations;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Stripe;
|
||||
|
||||
namespace Bit.Commercial.Core.AdminConsole.Providers;
|
||||
@ -24,7 +21,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
|
||||
private readonly IEventService _eventService;
|
||||
private readonly IMailService _mailService;
|
||||
private readonly IOrganizationRepository _organizationRepository;
|
||||
private readonly IOrganizationService _organizationService;
|
||||
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
|
||||
private readonly IStripeAdapter _stripeAdapter;
|
||||
private readonly IFeatureService _featureService;
|
||||
@ -32,26 +28,22 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
|
||||
private readonly ISubscriberService _subscriberService;
|
||||
private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery;
|
||||
private readonly IPricingClient _pricingClient;
|
||||
private readonly IAutomaticTaxStrategy _automaticTaxStrategy;
|
||||
|
||||
public RemoveOrganizationFromProviderCommand(
|
||||
IEventService eventService,
|
||||
IMailService mailService,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationService organizationService,
|
||||
IProviderOrganizationRepository providerOrganizationRepository,
|
||||
IStripeAdapter stripeAdapter,
|
||||
IFeatureService featureService,
|
||||
IProviderBillingService providerBillingService,
|
||||
ISubscriberService subscriberService,
|
||||
IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery,
|
||||
IPricingClient pricingClient,
|
||||
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
|
||||
IPricingClient pricingClient)
|
||||
{
|
||||
_eventService = eventService;
|
||||
_mailService = mailService;
|
||||
_organizationRepository = organizationRepository;
|
||||
_organizationService = organizationService;
|
||||
_providerOrganizationRepository = providerOrganizationRepository;
|
||||
_stripeAdapter = stripeAdapter;
|
||||
_featureService = featureService;
|
||||
@ -59,7 +51,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
|
||||
_subscriberService = subscriberService;
|
||||
_hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery;
|
||||
_pricingClient = pricingClient;
|
||||
_automaticTaxStrategy = automaticTaxStrategy;
|
||||
}
|
||||
|
||||
public async Task RemoveOrganizationFromProvider(
|
||||
@ -77,7 +68,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
|
||||
|
||||
if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync(
|
||||
providerOrganization.OrganizationId,
|
||||
Array.Empty<Guid>(),
|
||||
[],
|
||||
includeProvider: false))
|
||||
{
|
||||
throw new BadRequestException("Organization must have at least one confirmed owner.");
|
||||
@ -102,7 +93,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
|
||||
/// <summary>
|
||||
/// When a client organization is unlinked from a provider, we have to check if they're Stripe-enabled
|
||||
/// and, if they are, we remove their MSP discount and set their Subscription to `send_invoice`. This is because
|
||||
/// the provider's payment method will be removed from their Stripe customer causing ensuing charges to fail. Lastly,
|
||||
/// the provider's payment method will be removed from their Stripe customer, causing ensuing charges to fail. Lastly,
|
||||
/// we email the organization owners letting them know they need to add a new payment method.
|
||||
/// </summary>
|
||||
private async Task ResetOrganizationBillingAsync(
|
||||
@ -142,15 +133,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
|
||||
Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }]
|
||||
};
|
||||
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
|
||||
var setNonUSBusinessUseToReverseCharge = _featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (setNonUSBusinessUseToReverseCharge)
|
||||
{
|
||||
_automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
|
||||
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
|
||||
}
|
||||
else
|
||||
else if (customer.HasRecognizedTaxLocation())
|
||||
{
|
||||
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions
|
||||
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = true
|
||||
Enabled = customer.Address.Country == "US" ||
|
||||
customer.TaxIds.Any()
|
||||
};
|
||||
}
|
||||
|
||||
@ -187,7 +181,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
|
||||
await _mailService.SendProviderUpdatePaymentMethod(
|
||||
organization.Id,
|
||||
organization.Name,
|
||||
provider.Name,
|
||||
provider.Name!,
|
||||
organizationOwnerEmails);
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.AdminConsole.Enums.Provider;
|
||||
using Bit.Core.AdminConsole.Models.Business.Provider;
|
||||
using Bit.Core.AdminConsole.Models.Business.Tokenables;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.AdminConsole.Services;
|
||||
using Bit.Core.Billing.Enums;
|
||||
@ -53,6 +54,7 @@ public class ProviderService : IProviderService
|
||||
private readonly IApplicationCacheService _applicationCacheService;
|
||||
private readonly IProviderBillingService _providerBillingService;
|
||||
private readonly IPricingClient _pricingClient;
|
||||
private readonly IProviderClientOrganizationSignUpCommand _providerClientOrganizationSignUpCommand;
|
||||
|
||||
public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository,
|
||||
IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository,
|
||||
@ -61,7 +63,8 @@ public class ProviderService : IProviderService
|
||||
IOrganizationRepository organizationRepository, GlobalSettings globalSettings,
|
||||
ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService,
|
||||
IDataProtectorTokenFactory<ProviderDeleteTokenable> providerDeleteTokenDataFactory,
|
||||
IApplicationCacheService applicationCacheService, IProviderBillingService providerBillingService, IPricingClient pricingClient)
|
||||
IApplicationCacheService applicationCacheService, IProviderBillingService providerBillingService, IPricingClient pricingClient,
|
||||
IProviderClientOrganizationSignUpCommand providerClientOrganizationSignUpCommand)
|
||||
{
|
||||
_providerRepository = providerRepository;
|
||||
_providerUserRepository = providerUserRepository;
|
||||
@ -81,6 +84,7 @@ public class ProviderService : IProviderService
|
||||
_applicationCacheService = applicationCacheService;
|
||||
_providerBillingService = providerBillingService;
|
||||
_pricingClient = pricingClient;
|
||||
_providerClientOrganizationSignUpCommand = providerClientOrganizationSignUpCommand;
|
||||
}
|
||||
|
||||
public async Task<Provider> CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null)
|
||||
@ -560,12 +564,12 @@ public class ProviderService : IProviderService
|
||||
|
||||
ThrowOnInvalidPlanType(provider.Type, organizationSignup.Plan);
|
||||
|
||||
var (organization, _, defaultCollection) = await _organizationService.SignupClientAsync(organizationSignup);
|
||||
var signUpResponse = await _providerClientOrganizationSignUpCommand.SignUpClientOrganizationAsync(organizationSignup);
|
||||
|
||||
var providerOrganization = new ProviderOrganization
|
||||
{
|
||||
ProviderId = providerId,
|
||||
OrganizationId = organization.Id,
|
||||
OrganizationId = signUpResponse.Organization.Id,
|
||||
Key = organizationSignup.OwnerKey,
|
||||
};
|
||||
|
||||
@ -574,12 +578,12 @@ public class ProviderService : IProviderService
|
||||
|
||||
// Give the owner Can Manage access over the default collection
|
||||
// The orgUser is not available when the org is created so we have to do it here as part of the invite
|
||||
var defaultOwnerAccess = defaultCollection != null
|
||||
var defaultOwnerAccess = signUpResponse.DefaultCollection != null
|
||||
?
|
||||
[
|
||||
new CollectionAccessSelection
|
||||
{
|
||||
Id = defaultCollection.Id,
|
||||
Id = signUpResponse.DefaultCollection.Id,
|
||||
HidePasswords = false,
|
||||
ReadOnly = false,
|
||||
Manage = true
|
||||
@ -587,7 +591,7 @@ public class ProviderService : IProviderService
|
||||
]
|
||||
: Array.Empty<CollectionAccessSelection>();
|
||||
|
||||
await _organizationService.InviteUsersAsync(organization.Id, user.Id, systemUser: null,
|
||||
await _organizationService.InviteUsersAsync(signUpResponse.Organization.Id, user.Id, systemUser: null,
|
||||
new (OrganizationUserInvite, string)[]
|
||||
{
|
||||
(
|
||||
|
@ -67,6 +67,7 @@ public class BusinessUnitConverter(
|
||||
organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb;
|
||||
organization.UsePolicies = updatedPlan.HasPolicies;
|
||||
organization.UseSso = updatedPlan.HasSso;
|
||||
organization.UseOrganizationDomains = updatedPlan.HasOrganizationDomains;
|
||||
organization.UseGroups = updatedPlan.HasGroups;
|
||||
organization.UseEvents = updatedPlan.HasEvents;
|
||||
organization.UseDirectory = updatedPlan.HasDirectory;
|
||||
|
@ -18,7 +18,6 @@ using Bit.Core.Billing.Services;
|
||||
using Bit.Core.Billing.Services.Contracts;
|
||||
using Bit.Core.Billing.Tax.Models;
|
||||
using Bit.Core.Billing.Tax.Services;
|
||||
using Bit.Core.Billing.Tax.Services.Implementations;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Models.Business;
|
||||
@ -27,7 +26,6 @@ using Bit.Core.Services;
|
||||
using Bit.Core.Settings;
|
||||
using Braintree;
|
||||
using CsvHelper;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Stripe;
|
||||
|
||||
@ -52,8 +50,7 @@ public class ProviderBillingService(
|
||||
ISetupIntentCache setupIntentCache,
|
||||
IStripeAdapter stripeAdapter,
|
||||
ISubscriberService subscriberService,
|
||||
ITaxService taxService,
|
||||
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
|
||||
ITaxService taxService)
|
||||
: IProviderBillingService
|
||||
{
|
||||
public async Task AddExistingOrganization(
|
||||
@ -99,6 +96,7 @@ public class ProviderBillingService(
|
||||
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
|
||||
organization.UsePolicies = plan.HasPolicies;
|
||||
organization.UseSso = plan.HasSso;
|
||||
organization.UseOrganizationDomains = plan.HasOrganizationDomains;
|
||||
organization.UseGroups = plan.HasGroups;
|
||||
organization.UseEvents = plan.HasEvents;
|
||||
organization.UseDirectory = plan.HasDirectory;
|
||||
@ -127,7 +125,7 @@ public class ProviderBillingService(
|
||||
|
||||
/*
|
||||
* We have to scale the provider's seats before the ProviderOrganization
|
||||
* row is inserted so the added organization's seats don't get double counted.
|
||||
* row is inserted so the added organization's seats don't get double-counted.
|
||||
*/
|
||||
await ScaleSeats(provider, organization.PlanType, organization.Seats!.Value);
|
||||
|
||||
@ -235,7 +233,7 @@ public class ProviderBillingService(
|
||||
|
||||
var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions
|
||||
{
|
||||
Expand = ["tax_ids"]
|
||||
Expand = ["tax", "tax_ids"]
|
||||
});
|
||||
|
||||
var providerTaxId = providerCustomer.TaxIds.FirstOrDefault();
|
||||
@ -283,6 +281,13 @@ public class ProviderBillingService(
|
||||
]
|
||||
};
|
||||
|
||||
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (setNonUSBusinessUseToReverseCharge && providerCustomer.Address is not { Country: "US" })
|
||||
{
|
||||
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
|
||||
}
|
||||
|
||||
var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
|
||||
|
||||
organization.GatewayCustomerId = customer.Id;
|
||||
@ -519,6 +524,13 @@ public class ProviderBillingService(
|
||||
}
|
||||
};
|
||||
|
||||
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (setNonUSBusinessUseToReverseCharge && taxInfo.BillingAddressCountry != "US")
|
||||
{
|
||||
options.TaxExempt = StripeConstants.TaxExempt.Reverse;
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber))
|
||||
{
|
||||
var taxIdType = taxService.GetStripeTaxCode(
|
||||
@ -530,6 +542,7 @@ public class ProviderBillingService(
|
||||
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
|
||||
taxInfo.BillingAddressCountry,
|
||||
taxInfo.TaxIdNumber);
|
||||
|
||||
throw new BadRequestException("billingTaxIdTypeInferenceError");
|
||||
}
|
||||
|
||||
@ -694,6 +707,13 @@ public class ProviderBillingService(
|
||||
customer.Metadata.ContainsKey(BraintreeCustomerIdKey) ||
|
||||
setupIntent.IsUnverifiedBankAccount());
|
||||
|
||||
int? trialPeriodDays = provider.Type switch
|
||||
{
|
||||
ProviderType.Msp when usePaymentMethod => 14,
|
||||
ProviderType.BusinessUnit when usePaymentMethod => 4,
|
||||
_ => null
|
||||
};
|
||||
|
||||
var subscriptionCreateOptions = new SubscriptionCreateOptions
|
||||
{
|
||||
CollectionMethod = usePaymentMethod ?
|
||||
@ -707,17 +727,24 @@ public class ProviderBillingService(
|
||||
},
|
||||
OffSession = true,
|
||||
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations,
|
||||
TrialPeriodDays = usePaymentMethod ? 14 : null
|
||||
TrialPeriodDays = trialPeriodDays
|
||||
};
|
||||
|
||||
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
|
||||
{
|
||||
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
|
||||
}
|
||||
else
|
||||
var setNonUSBusinessUseToReverseCharge =
|
||||
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (setNonUSBusinessUseToReverseCharge)
|
||||
{
|
||||
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
|
||||
}
|
||||
else if (customer.HasRecognizedTaxLocation())
|
||||
{
|
||||
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = customer.Address.Country == "US" ||
|
||||
customer.TaxIds.Any()
|
||||
};
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
|
@ -1,4 +1,5 @@
|
||||
using Bit.Commercial.Core.AdminConsole.Providers;
|
||||
using Bit.Core;
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.AdminConsole.Enums.Provider;
|
||||
@ -8,7 +9,6 @@ using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Billing.Services;
|
||||
using Bit.Core.Billing.Tax.Services;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Repositories;
|
||||
@ -224,31 +224,115 @@ public class RemoveOrganizationFromProviderCommandTests
|
||||
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
|
||||
options.Description == string.Empty &&
|
||||
options.Email == organization.BillingEmail &&
|
||||
options.Expand[0] == "tax" &&
|
||||
options.Expand[1] == "tax_ids")).Returns(new Customer
|
||||
{
|
||||
Id = "customer_id",
|
||||
Address = new Address
|
||||
{
|
||||
Country = "US"
|
||||
}
|
||||
});
|
||||
|
||||
stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription
|
||||
{
|
||||
Id = "subscription_id"
|
||||
});
|
||||
|
||||
sutProvider.GetDependency<IAutomaticTaxStrategy>()
|
||||
.When(x => x.SetCreateOptions(
|
||||
Arg.Is<SubscriptionCreateOptions>(options =>
|
||||
options.Customer == organization.GatewayCustomerId &&
|
||||
options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice &&
|
||||
options.DaysUntilDue == 30 &&
|
||||
options.Metadata["organizationId"] == organization.Id.ToString() &&
|
||||
options.OffSession == true &&
|
||||
options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
|
||||
options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId &&
|
||||
options.Items.First().Quantity == organization.Seats)
|
||||
, Arg.Any<Customer>()))
|
||||
.Do(x =>
|
||||
await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
|
||||
|
||||
await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>
|
||||
options.Customer == organization.GatewayCustomerId &&
|
||||
options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice &&
|
||||
options.DaysUntilDue == 30 &&
|
||||
options.AutomaticTax.Enabled == true &&
|
||||
options.Metadata["organizationId"] == organization.Id.ToString() &&
|
||||
options.OffSession == true &&
|
||||
options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
|
||||
options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId &&
|
||||
options.Items.First().Quantity == organization.Seats));
|
||||
|
||||
await sutProvider.GetDependency<IProviderBillingService>().Received(1)
|
||||
.ScaleSeats(provider, organization.PlanType, -organization.Seats ?? 0);
|
||||
|
||||
await organizationRepository.Received(1).ReplaceAsync(Arg.Is<Organization>(
|
||||
org =>
|
||||
org.BillingEmail == "a@example.com" &&
|
||||
org.GatewaySubscriptionId == "subscription_id" &&
|
||||
org.Status == OrganizationStatusType.Created));
|
||||
|
||||
await sutProvider.GetDependency<IProviderOrganizationRepository>().Received(1)
|
||||
.DeleteAsync(providerOrganization);
|
||||
|
||||
await sutProvider.GetDependency<IEventService>().Received(1)
|
||||
.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed);
|
||||
|
||||
await sutProvider.GetDependency<IMailService>().Received(1)
|
||||
.SendProviderUpdatePaymentMethod(
|
||||
organization.Id,
|
||||
organization.Name,
|
||||
provider.Name,
|
||||
Arg.Is<IEnumerable<string>>(emails => emails.FirstOrDefault() == "a@example.com"));
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task RemoveOrganizationFromProvider_OrganizationStripeEnabled_ConsolidatedBilling_ReverseCharge_MakesCorrectInvocations(
|
||||
Provider provider,
|
||||
ProviderOrganization providerOrganization,
|
||||
Organization organization,
|
||||
SutProvider<RemoveOrganizationFromProviderCommand> sutProvider)
|
||||
{
|
||||
provider.Status = ProviderStatusType.Billable;
|
||||
|
||||
providerOrganization.ProviderId = provider.Id;
|
||||
|
||||
organization.Status = OrganizationStatusType.Managed;
|
||||
|
||||
organization.PlanType = PlanType.TeamsMonthly;
|
||||
|
||||
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
|
||||
|
||||
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan);
|
||||
|
||||
sutProvider.GetDependency<IHasConfirmedOwnersExceptQuery>().HasConfirmedOwnersExceptAsync(
|
||||
providerOrganization.OrganizationId,
|
||||
[],
|
||||
includeProvider: false)
|
||||
.Returns(true);
|
||||
|
||||
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
|
||||
|
||||
organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns([
|
||||
"a@example.com",
|
||||
"b@example.com"
|
||||
]);
|
||||
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
|
||||
options.Description == string.Empty &&
|
||||
options.Email == organization.BillingEmail &&
|
||||
options.Expand[0] == "tax" &&
|
||||
options.Expand[1] == "tax_ids")).Returns(new Customer
|
||||
{
|
||||
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
Id = "customer_id",
|
||||
Address = new Address
|
||||
{
|
||||
Enabled = true
|
||||
};
|
||||
Country = "US"
|
||||
}
|
||||
});
|
||||
|
||||
stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription
|
||||
{
|
||||
Id = "subscription_id"
|
||||
});
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
|
||||
|
||||
await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
|
||||
|
||||
await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>
|
||||
|
@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.AdminConsole.Enums.Provider;
|
||||
using Bit.Core.AdminConsole.Models.Business.Provider;
|
||||
using Bit.Core.AdminConsole.Models.Business.Tokenables;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Models;
|
||||
@ -717,8 +718,8 @@ public class ProviderServiceTests
|
||||
|
||||
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
|
||||
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
|
||||
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup)
|
||||
.Returns((organization, null as OrganizationUser, new Collection()));
|
||||
sutProvider.GetDependency<IProviderClientOrganizationSignUpCommand>().SignUpClientOrganizationAsync(organizationSignup)
|
||||
.Returns(new ProviderClientOrganizationSignUpResponse(organization, new Collection()));
|
||||
|
||||
var providerOrganization =
|
||||
await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user);
|
||||
@ -755,8 +756,8 @@ public class ProviderServiceTests
|
||||
|
||||
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
|
||||
|
||||
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup)
|
||||
.Returns((organization, null as OrganizationUser, new Collection()));
|
||||
sutProvider.GetDependency<IProviderClientOrganizationSignUpCommand>().SignUpClientOrganizationAsync(organizationSignup)
|
||||
.Returns(new ProviderClientOrganizationSignUpResponse(organization, new Collection()));
|
||||
|
||||
await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user));
|
||||
@ -782,8 +783,8 @@ public class ProviderServiceTests
|
||||
|
||||
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
|
||||
|
||||
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup)
|
||||
.Returns((organization, null as OrganizationUser, new Collection()));
|
||||
sutProvider.GetDependency<IProviderClientOrganizationSignUpCommand>().SignUpClientOrganizationAsync(organizationSignup)
|
||||
.Returns(new ProviderClientOrganizationSignUpResponse(organization, new Collection()));
|
||||
|
||||
var providerOrganization = await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user);
|
||||
|
||||
@ -821,8 +822,8 @@ public class ProviderServiceTests
|
||||
|
||||
sutProvider.GetDependency<IProviderRepository>().GetByIdAsync(provider.Id).Returns(provider);
|
||||
var providerOrganizationRepository = sutProvider.GetDependency<IProviderOrganizationRepository>();
|
||||
sutProvider.GetDependency<IOrganizationService>().SignupClientAsync(organizationSignup)
|
||||
.Returns((organization, null as OrganizationUser, defaultCollection));
|
||||
sutProvider.GetDependency<IProviderClientOrganizationSignUpCommand>().SignUpClientOrganizationAsync(organizationSignup)
|
||||
.Returns(new ProviderClientOrganizationSignUpResponse(organization, defaultCollection));
|
||||
|
||||
var providerOrganization =
|
||||
await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user);
|
||||
|
@ -262,7 +262,7 @@ public class ProviderBillingServiceTests
|
||||
};
|
||||
|
||||
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>(
|
||||
options => options.Expand.FirstOrDefault() == "tax_ids"))
|
||||
options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids")))
|
||||
.Returns(providerCustomer);
|
||||
|
||||
sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri
|
||||
@ -312,6 +312,91 @@ public class ProviderBillingServiceTests
|
||||
org => org.GatewayCustomerId == "customer_id"));
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task CreateCustomer_ForClientOrg_ReverseCharge_Succeeds(
|
||||
Provider provider,
|
||||
Organization organization,
|
||||
SutProvider<ProviderBillingService> sutProvider)
|
||||
{
|
||||
organization.GatewayCustomerId = null;
|
||||
organization.Name = "Name";
|
||||
organization.BusinessName = "BusinessName";
|
||||
|
||||
var providerCustomer = new Customer
|
||||
{
|
||||
Address = new Address
|
||||
{
|
||||
Country = "CA",
|
||||
PostalCode = "12345",
|
||||
Line1 = "123 Main St.",
|
||||
Line2 = "Unit 4",
|
||||
City = "Fake Town",
|
||||
State = "Fake State"
|
||||
},
|
||||
TaxIds = new StripeList<TaxId>
|
||||
{
|
||||
Data =
|
||||
[
|
||||
new TaxId { Type = "TYPE", Value = "VALUE" }
|
||||
]
|
||||
}
|
||||
};
|
||||
|
||||
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>(
|
||||
options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids")))
|
||||
.Returns(providerCustomer);
|
||||
|
||||
sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri
|
||||
.Returns(new Bit.Core.Settings.GlobalSettings.BaseServiceUriSettings(new Bit.Core.Settings.GlobalSettings())
|
||||
{
|
||||
CloudRegion = "US"
|
||||
});
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
|
||||
|
||||
sutProvider.GetDependency<IStripeAdapter>().CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
|
||||
options =>
|
||||
options.Address.Country == providerCustomer.Address.Country &&
|
||||
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
|
||||
options.Address.Line1 == providerCustomer.Address.Line1 &&
|
||||
options.Address.Line2 == providerCustomer.Address.Line2 &&
|
||||
options.Address.City == providerCustomer.Address.City &&
|
||||
options.Address.State == providerCustomer.Address.State &&
|
||||
options.Name == organization.DisplayName() &&
|
||||
options.Description == $"{provider.Name} Client Organization" &&
|
||||
options.Email == provider.BillingEmail &&
|
||||
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
|
||||
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
|
||||
options.Metadata["region"] == "US" &&
|
||||
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
|
||||
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value &&
|
||||
options.TaxExempt == StripeConstants.TaxExempt.Reverse))
|
||||
.Returns(new Customer { Id = "customer_id" });
|
||||
|
||||
await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization);
|
||||
|
||||
await sutProvider.GetDependency<IStripeAdapter>().Received(1).CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
|
||||
options =>
|
||||
options.Address.Country == providerCustomer.Address.Country &&
|
||||
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
|
||||
options.Address.Line1 == providerCustomer.Address.Line1 &&
|
||||
options.Address.Line2 == providerCustomer.Address.Line2 &&
|
||||
options.Address.City == providerCustomer.Address.City &&
|
||||
options.Address.State == providerCustomer.Address.State &&
|
||||
options.Name == organization.DisplayName() &&
|
||||
options.Description == $"{provider.Name} Client Organization" &&
|
||||
options.Email == provider.BillingEmail &&
|
||||
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
|
||||
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
|
||||
options.Metadata["region"] == "US" &&
|
||||
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
|
||||
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value));
|
||||
|
||||
await sutProvider.GetDependency<IOrganizationRepository>().Received(1).ReplaceAsync(Arg.Is<Organization>(
|
||||
org => org.GatewayCustomerId == "customer_id"));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region GenerateClientInvoiceReport
|
||||
@ -1182,6 +1267,62 @@ public class ProviderBillingServiceTests
|
||||
Assert.Equivalent(expected, actual);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task SetupCustomer_WithCard_ReverseCharge_Success(
|
||||
SutProvider<ProviderBillingService> sutProvider,
|
||||
Provider provider,
|
||||
TaxInfo taxInfo)
|
||||
{
|
||||
provider.Name = "MSP";
|
||||
|
||||
sutProvider.GetDependency<ITaxService>()
|
||||
.GetStripeTaxCode(Arg.Is<string>(
|
||||
p => p == taxInfo.BillingAddressCountry),
|
||||
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
|
||||
.Returns(taxInfo.TaxIdType);
|
||||
|
||||
taxInfo.BillingAddressCountry = "AD";
|
||||
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
var expected = new Customer
|
||||
{
|
||||
Id = "customer_id",
|
||||
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
|
||||
};
|
||||
|
||||
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token");
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
|
||||
|
||||
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
|
||||
o.Address.Country == taxInfo.BillingAddressCountry &&
|
||||
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
|
||||
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
|
||||
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
|
||||
o.Address.City == taxInfo.BillingAddressCity &&
|
||||
o.Address.State == taxInfo.BillingAddressState &&
|
||||
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
|
||||
o.Email == provider.BillingEmail &&
|
||||
o.PaymentMethod == tokenizedPaymentSource.Token &&
|
||||
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token &&
|
||||
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
|
||||
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
|
||||
o.Metadata["region"] == "" &&
|
||||
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
|
||||
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber &&
|
||||
o.TaxExempt == StripeConstants.TaxExempt.Reverse))
|
||||
.Returns(expected);
|
||||
|
||||
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
|
||||
|
||||
Assert.Equivalent(expected, actual);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid(
|
||||
SutProvider<ProviderBillingService> sutProvider,
|
||||
@ -1307,7 +1448,7 @@ public class ProviderBillingServiceTests
|
||||
.Returns(new Customer
|
||||
{
|
||||
Id = "customer_id",
|
||||
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
|
||||
Address = new Address { Country = "US" }
|
||||
});
|
||||
|
||||
var providerPlans = new List<ProviderPlan>
|
||||
@ -1359,7 +1500,7 @@ public class ProviderBillingServiceTests
|
||||
var customer = new Customer
|
||||
{
|
||||
Id = "customer_id",
|
||||
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
|
||||
Address = new Address { Country = "US" }
|
||||
};
|
||||
sutProvider.GetDependency<ISubscriberService>()
|
||||
.GetCustomerOrThrow(
|
||||
@ -1399,19 +1540,6 @@ public class ProviderBillingServiceTests
|
||||
|
||||
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
|
||||
|
||||
sutProvider.GetDependency<IAutomaticTaxStrategy>()
|
||||
.When(x => x.SetCreateOptions(
|
||||
Arg.Is<SubscriptionCreateOptions>(options =>
|
||||
options.Customer == "customer_id")
|
||||
, Arg.Is<Customer>(p => p == customer)))
|
||||
.Do(x =>
|
||||
{
|
||||
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = true
|
||||
};
|
||||
});
|
||||
|
||||
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
|
||||
sub =>
|
||||
sub.AutomaticTax.Enabled == true &&
|
||||
@ -1443,11 +1571,11 @@ public class ProviderBillingServiceTests
|
||||
var customer = new Customer
|
||||
{
|
||||
Id = "customer_id",
|
||||
Address = new Address { Country = "US" },
|
||||
InvoiceSettings = new CustomerInvoiceSettings
|
||||
{
|
||||
DefaultPaymentMethodId = "pm_123"
|
||||
},
|
||||
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
|
||||
}
|
||||
};
|
||||
|
||||
sutProvider.GetDependency<ISubscriberService>()
|
||||
@ -1488,19 +1616,6 @@ public class ProviderBillingServiceTests
|
||||
|
||||
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
|
||||
|
||||
sutProvider.GetDependency<IAutomaticTaxStrategy>()
|
||||
.When(x => x.SetCreateOptions(
|
||||
Arg.Is<SubscriptionCreateOptions>(options =>
|
||||
options.Customer == "customer_id")
|
||||
, Arg.Is<Customer>(p => p == customer)))
|
||||
.Do(x =>
|
||||
{
|
||||
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = true
|
||||
};
|
||||
});
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
|
||||
|
||||
@ -1536,9 +1651,9 @@ public class ProviderBillingServiceTests
|
||||
var customer = new Customer
|
||||
{
|
||||
Id = "customer_id",
|
||||
Address = new Address { Country = "US" },
|
||||
InvoiceSettings = new CustomerInvoiceSettings(),
|
||||
Metadata = new Dictionary<string, string>(),
|
||||
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
|
||||
Metadata = new Dictionary<string, string>()
|
||||
};
|
||||
|
||||
sutProvider.GetDependency<ISubscriberService>()
|
||||
@ -1579,19 +1694,6 @@ public class ProviderBillingServiceTests
|
||||
|
||||
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
|
||||
|
||||
sutProvider.GetDependency<IAutomaticTaxStrategy>()
|
||||
.When(x => x.SetCreateOptions(
|
||||
Arg.Is<SubscriptionCreateOptions>(options =>
|
||||
options.Customer == "customer_id")
|
||||
, Arg.Is<Customer>(p => p == customer)))
|
||||
.Do(x =>
|
||||
{
|
||||
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = true
|
||||
};
|
||||
});
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
|
||||
|
||||
@ -1646,12 +1748,15 @@ public class ProviderBillingServiceTests
|
||||
var customer = new Customer
|
||||
{
|
||||
Id = "customer_id",
|
||||
Address = new Address
|
||||
{
|
||||
Country = "US"
|
||||
},
|
||||
InvoiceSettings = new CustomerInvoiceSettings(),
|
||||
Metadata = new Dictionary<string, string>
|
||||
{
|
||||
["btCustomerId"] = "braintree_customer_id"
|
||||
},
|
||||
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
|
||||
}
|
||||
};
|
||||
|
||||
sutProvider.GetDependency<ISubscriberService>()
|
||||
@ -1692,22 +1797,92 @@ public class ProviderBillingServiceTests
|
||||
|
||||
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
|
||||
|
||||
sutProvider.GetDependency<IAutomaticTaxStrategy>()
|
||||
.When(x => x.SetCreateOptions(
|
||||
Arg.Is<SubscriptionCreateOptions>(options =>
|
||||
options.Customer == "customer_id")
|
||||
, Arg.Is<Customer>(p => p == customer)))
|
||||
.Do(x =>
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
|
||||
|
||||
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
|
||||
sub =>
|
||||
sub.AutomaticTax.Enabled == true &&
|
||||
sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically &&
|
||||
sub.Customer == "customer_id" &&
|
||||
sub.DaysUntilDue == null &&
|
||||
sub.Items.Count == 2 &&
|
||||
sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams &&
|
||||
sub.Items.ElementAt(0).Quantity == 100 &&
|
||||
sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise &&
|
||||
sub.Items.ElementAt(1).Quantity == 100 &&
|
||||
sub.Metadata["providerId"] == provider.Id.ToString() &&
|
||||
sub.OffSession == true &&
|
||||
sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
|
||||
sub.TrialPeriodDays == 14)).Returns(expected);
|
||||
|
||||
var actual = await sutProvider.Sut.SetupSubscription(provider);
|
||||
|
||||
Assert.Equivalent(expected, actual);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task SetupSubscription_ReverseCharge_Succeeds(
|
||||
SutProvider<ProviderBillingService> sutProvider,
|
||||
Provider provider)
|
||||
{
|
||||
provider.Type = ProviderType.Msp;
|
||||
provider.GatewaySubscriptionId = null;
|
||||
|
||||
var customer = new Customer
|
||||
{
|
||||
Id = "customer_id",
|
||||
Address = new Address { Country = "CA" },
|
||||
InvoiceSettings = new CustomerInvoiceSettings
|
||||
{
|
||||
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = true
|
||||
};
|
||||
});
|
||||
DefaultPaymentMethodId = "pm_123"
|
||||
}
|
||||
};
|
||||
|
||||
sutProvider.GetDependency<ISubscriberService>()
|
||||
.GetCustomerOrThrow(
|
||||
provider,
|
||||
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer);
|
||||
|
||||
var providerPlans = new List<ProviderPlan>
|
||||
{
|
||||
new()
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
ProviderId = provider.Id,
|
||||
PlanType = PlanType.TeamsMonthly,
|
||||
SeatMinimum = 100,
|
||||
PurchasedSeats = 0,
|
||||
AllocatedSeats = 0
|
||||
},
|
||||
new()
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
ProviderId = provider.Id,
|
||||
PlanType = PlanType.EnterpriseMonthly,
|
||||
SeatMinimum = 100,
|
||||
PurchasedSeats = 0,
|
||||
AllocatedSeats = 0
|
||||
}
|
||||
};
|
||||
|
||||
foreach (var plan in providerPlans)
|
||||
{
|
||||
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(plan.PlanType)
|
||||
.Returns(StaticStore.GetPlan(plan.PlanType));
|
||||
}
|
||||
|
||||
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id)
|
||||
.Returns(providerPlans);
|
||||
|
||||
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
|
||||
|
||||
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
|
||||
sub =>
|
||||
sub.AutomaticTax.Enabled == true &&
|
||||
|
@ -11,6 +11,7 @@ MAILCATCHER_PORT=1080
|
||||
# Alternative databases
|
||||
POSTGRES_PASSWORD=SET_A_PASSWORD_HERE_123
|
||||
MYSQL_ROOT_PASSWORD=SET_A_PASSWORD_HERE_123
|
||||
MARIADB_ROOT_PASSWORD=SET_A_PASSWORD_HERE_123
|
||||
|
||||
# IdP configuration
|
||||
# Complete using the values from the Manage SSO page in the web vault
|
||||
|
@ -70,6 +70,20 @@ services:
|
||||
profiles:
|
||||
- mysql
|
||||
|
||||
mariadb:
|
||||
image: mariadb:10
|
||||
ports:
|
||||
- 4306:3306
|
||||
environment:
|
||||
MARIADB_USER: maria
|
||||
MARIADB_PASSWORD: ${MARIADB_ROOT_PASSWORD}
|
||||
MARIADB_DATABASE: vault_dev
|
||||
MARIADB_RANDOM_ROOT_PASSWORD: "true"
|
||||
volumes:
|
||||
- mariadb_dev_data:/var/lib/mysql
|
||||
profiles:
|
||||
- mariadb
|
||||
|
||||
idp:
|
||||
image: kenchan0130/simplesamlphp:1.19.8
|
||||
container_name: idp
|
||||
|
@ -5,6 +5,7 @@ param(
|
||||
[switch]$all,
|
||||
[switch]$postgres,
|
||||
[switch]$mysql,
|
||||
[switch]$mariadb,
|
||||
[switch]$mssql,
|
||||
[switch]$sqlite,
|
||||
[switch]$selfhost,
|
||||
@ -15,11 +16,15 @@ param(
|
||||
$ErrorActionPreference = "Stop"
|
||||
$currentDir = Get-Location
|
||||
|
||||
if (!$all -and !$postgres -and !$mysql -and !$sqlite) {
|
||||
function Get-IsEFDatabase {
|
||||
return $postgres -or $mysql -or $mariadb -or $sqlite;
|
||||
}
|
||||
|
||||
if (!$all -and !$(Get-IsEFDatabase)) {
|
||||
$mssql = $true;
|
||||
}
|
||||
|
||||
if ($all -or $postgres -or $mysql -or $sqlite) {
|
||||
if ($all -or $(Get-IsEFDatabase)) {
|
||||
dotnet ef *> $null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Host "Entity Framework Core tools were not found in the dotnet global tools. Attempting to install"
|
||||
@ -60,9 +65,12 @@ if ($all -or $mssql) {
|
||||
}
|
||||
|
||||
Foreach ($item in @(
|
||||
@($mysql, "MySQL", "MySqlMigrations", "mySql", 2),
|
||||
@($postgres, "PostgreSQL", "PostgresMigrations", "postgreSql", 0),
|
||||
@($sqlite, "SQLite", "SqliteMigrations", "sqlite", 1)
|
||||
@($sqlite, "SQLite", "SqliteMigrations", "sqlite", 1),
|
||||
@($mysql, "MySQL", "MySqlMigrations", "mySql", 2),
|
||||
# MariaDB shares the MySQL connection string in the server config so they are mutually exclusive in that context.
|
||||
# However they can still be run independently for integration tests.
|
||||
@($mariadb, "MariaDB", "MySqlMigrations", "mySql", 3)
|
||||
)) {
|
||||
if (!$item[0] -and !$all) {
|
||||
continue
|
||||
|
@ -40,8 +40,6 @@ export function authenticate(
|
||||
payload["deviceName"] = "chrome";
|
||||
payload["username"] = username;
|
||||
payload["password"] = password;
|
||||
|
||||
params.headers["Auth-Email"] = encoding.b64encode(username);
|
||||
} else {
|
||||
payload["scope"] = "api.organization";
|
||||
payload["grant_type"] = "client_credentials";
|
||||
|
@ -69,7 +69,7 @@
|
||||
document.getElementById('@(nameof(Model.UseGroups))').checked = plan.hasGroups;
|
||||
document.getElementById('@(nameof(Model.UsePolicies))').checked = plan.hasPolicies;
|
||||
document.getElementById('@(nameof(Model.UseSso))').checked = plan.hasSso;
|
||||
document.getElementById('@(nameof(Model.UseOrganizationDomains))').checked = hasOrganizationDomains;
|
||||
document.getElementById('@(nameof(Model.UseOrganizationDomains))').checked = plan.hasOrganizationDomains;
|
||||
document.getElementById('@(nameof(Model.UseScim))').checked = plan.hasScim;
|
||||
document.getElementById('@(nameof(Model.UseDirectory))').checked = plan.hasDirectory;
|
||||
document.getElementById('@(nameof(Model.UseEvents))').checked = plan.hasEvents;
|
||||
|
@ -2,7 +2,7 @@
|
||||
using Bit.Core;
|
||||
using Bit.Core.Jobs;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
|
||||
using Quartz;
|
||||
|
||||
namespace Bit.Admin.Tools.Jobs;
|
||||
@ -32,10 +32,10 @@ public class DeleteSendsJob : BaseJob
|
||||
}
|
||||
using (var scope = _serviceProvider.CreateScope())
|
||||
{
|
||||
var sendService = scope.ServiceProvider.GetRequiredService<ISendService>();
|
||||
var nonAnonymousSendCommand = scope.ServiceProvider.GetRequiredService<INonAnonymousSendCommand>();
|
||||
foreach (var send in sends)
|
||||
{
|
||||
await sendService.DeleteSendAsync(send);
|
||||
await nonAnonymousSendCommand.DeleteSendAsync(send);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,13 +2,11 @@
|
||||
using Bit.Api.AdminConsole.Models.Request.Organizations;
|
||||
using Bit.Api.AdminConsole.Models.Response.Organizations;
|
||||
using Bit.Api.Models.Response;
|
||||
using Bit.Core;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Utilities;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
using Microsoft.AspNetCore.Mvc;
|
||||
|
||||
@ -137,7 +135,6 @@ public class OrganizationDomainController : Controller
|
||||
|
||||
[AllowAnonymous]
|
||||
[HttpPost("domain/sso/verified")]
|
||||
[RequireFeature(FeatureFlagKeys.VerifiedSsoDomainEndpoint)]
|
||||
public async Task<VerifiedOrganizationDomainSsoDetailsResponseModel> GetVerifiedOrgDomainSsoDetailsAsync(
|
||||
[FromBody] OrganizationDomainSsoDetailsRequestModel model)
|
||||
{
|
||||
|
@ -292,15 +292,17 @@ public class OrganizationBillingController(
|
||||
sale.Organization.PlanType = plan.Type;
|
||||
sale.Organization.Plan = plan.Name;
|
||||
sale.SubscriptionSetup.SkipTrial = true;
|
||||
await organizationBillingService.Finalize(sale);
|
||||
|
||||
if (organizationSignup.PaymentMethodType == null || string.IsNullOrEmpty(organizationSignup.PaymentToken))
|
||||
{
|
||||
return Error.BadRequest("A payment method is required to restart the subscription.");
|
||||
}
|
||||
var org = await organizationRepository.GetByIdAsync(organizationId);
|
||||
Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine.");
|
||||
if (organizationSignup.PaymentMethodType != null)
|
||||
{
|
||||
var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken);
|
||||
var taxInformation = TaxInformation.From(organizationSignup.TaxInfo);
|
||||
await organizationBillingService.UpdatePaymentMethod(org, paymentSource, taxInformation);
|
||||
}
|
||||
var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken);
|
||||
var taxInformation = TaxInformation.From(organizationSignup.TaxInfo);
|
||||
await organizationBillingService.UpdatePaymentMethod(org, paymentSource, taxInformation);
|
||||
await organizationBillingService.Finalize(sale);
|
||||
|
||||
return TypedResults.Ok();
|
||||
}
|
||||
|
@ -207,7 +207,7 @@ public class OrganizationSponsorshipsController : Controller
|
||||
[HttpDelete("{sponsoringOrganizationId}")]
|
||||
[HttpPost("{sponsoringOrganizationId}/delete")]
|
||||
[SelfHosted(NotSelfHostedOnly = true)]
|
||||
public async Task RevokeSponsorship(Guid sponsoringOrganizationId, [FromQuery] bool isAdminInitiated = false)
|
||||
public async Task RevokeSponsorship(Guid sponsoringOrganizationId)
|
||||
{
|
||||
|
||||
var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrganizationId, _currentContext.UserId ?? default);
|
||||
@ -217,11 +217,25 @@ public class OrganizationSponsorshipsController : Controller
|
||||
}
|
||||
|
||||
var existingOrgSponsorship = await _organizationSponsorshipRepository
|
||||
.GetBySponsoringOrganizationUserIdAsync(orgUser.Id, isAdminInitiated);
|
||||
.GetBySponsoringOrganizationUserIdAsync(orgUser.Id);
|
||||
|
||||
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
|
||||
}
|
||||
|
||||
[Authorize("Application")]
|
||||
[HttpDelete("{sponsoringOrgId}/{sponsoredFriendlyName}/revoke")]
|
||||
[SelfHosted(NotSelfHostedOnly = true)]
|
||||
public async Task AdminInitiatedRevokeSponsorshipAsync(Guid sponsoringOrgId, string sponsoredFriendlyName)
|
||||
{
|
||||
var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrgId);
|
||||
var existingOrgSponsorship = sponsorships.FirstOrDefault(s => s.FriendlyName != null && s.FriendlyName.Equals(sponsoredFriendlyName, StringComparison.OrdinalIgnoreCase));
|
||||
if (existingOrgSponsorship == null)
|
||||
{
|
||||
throw new BadRequestException("The specified sponsored organization could not be found under the given sponsoring organization.");
|
||||
}
|
||||
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
|
||||
}
|
||||
|
||||
[Authorize("Application")]
|
||||
[HttpDelete("sponsored/{sponsoredOrgId}")]
|
||||
[HttpPost("sponsored/{sponsoredOrgId}/remove")]
|
||||
|
@ -109,28 +109,6 @@ public class OrganizationsController(
|
||||
return license;
|
||||
}
|
||||
|
||||
[HttpPost("{id:guid}/payment")]
|
||||
[SelfHosted(NotSelfHostedOnly = true)]
|
||||
public async Task PostPayment(Guid id, [FromBody] PaymentRequestModel model)
|
||||
{
|
||||
if (!await currentContext.EditPaymentMethods(id))
|
||||
{
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
await organizationService.ReplacePaymentMethodAsync(id, model.PaymentToken,
|
||||
model.PaymentMethodType.Value, new TaxInfo
|
||||
{
|
||||
BillingAddressLine1 = model.Line1,
|
||||
BillingAddressLine2 = model.Line2,
|
||||
BillingAddressState = model.State,
|
||||
BillingAddressCity = model.City,
|
||||
BillingAddressPostalCode = model.PostalCode,
|
||||
BillingAddressCountry = model.Country,
|
||||
TaxIdNumber = model.TaxId,
|
||||
});
|
||||
}
|
||||
|
||||
[HttpPost("{id:guid}/upgrade")]
|
||||
[SelfHosted(NotSelfHostedOnly = true)]
|
||||
public async Task<PaymentResponseModel> PostUpgrade(Guid id, [FromBody] OrganizationUpgradeRequestModel model)
|
||||
|
@ -1,6 +1,10 @@
|
||||
using Bit.Api.Models.Request.Organizations;
|
||||
using Bit.Api.AdminConsole.Authorization.Requirements;
|
||||
using Bit.Api.Models.Request.Organizations;
|
||||
using Bit.Api.Models.Response;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Models.Api.Response.OrganizationSponsorships;
|
||||
using Bit.Core.Models.Data.Organizations.OrganizationSponsorships;
|
||||
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
@ -22,6 +26,7 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
|
||||
private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand;
|
||||
private readonly ICurrentContext _currentContext;
|
||||
private readonly IFeatureService _featureService;
|
||||
private readonly IAuthorizationService _authorizationService;
|
||||
|
||||
public SelfHostedOrganizationSponsorshipsController(
|
||||
ICreateSponsorshipCommand offerSponsorshipCommand,
|
||||
@ -30,7 +35,8 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
|
||||
IOrganizationSponsorshipRepository organizationSponsorshipRepository,
|
||||
IOrganizationUserRepository organizationUserRepository,
|
||||
ICurrentContext currentContext,
|
||||
IFeatureService featureService
|
||||
IFeatureService featureService,
|
||||
IAuthorizationService authorizationService
|
||||
)
|
||||
{
|
||||
_offerSponsorshipCommand = offerSponsorshipCommand;
|
||||
@ -40,6 +46,7 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
|
||||
_organizationUserRepository = organizationUserRepository;
|
||||
_currentContext = currentContext;
|
||||
_featureService = featureService;
|
||||
_authorizationService = authorizationService;
|
||||
}
|
||||
|
||||
[HttpPost("{sponsoringOrgId}/families-for-enterprise")]
|
||||
@ -84,4 +91,41 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
|
||||
|
||||
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
|
||||
}
|
||||
|
||||
[HttpDelete("{sponsoringOrgId}/{sponsoredFriendlyName}/revoke")]
|
||||
public async Task AdminInitiatedRevokeSponsorshipAsync(Guid sponsoringOrgId, string sponsoredFriendlyName)
|
||||
{
|
||||
var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrgId);
|
||||
var existingOrgSponsorship = sponsorships.FirstOrDefault(s => s.FriendlyName != null && s.FriendlyName.Equals(sponsoredFriendlyName, StringComparison.OrdinalIgnoreCase));
|
||||
if (existingOrgSponsorship == null)
|
||||
{
|
||||
throw new BadRequestException("The specified sponsored organization could not be found under the given sponsoring organization.");
|
||||
}
|
||||
await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship);
|
||||
}
|
||||
|
||||
[Authorize("Application")]
|
||||
[HttpGet("{orgId}/sponsored")]
|
||||
public async Task<ListResponseModel<OrganizationSponsorshipInvitesResponseModel>> GetSponsoredOrganizations(Guid orgId)
|
||||
{
|
||||
var sponsoringOrg = await _organizationRepository.GetByIdAsync(orgId);
|
||||
if (sponsoringOrg == null)
|
||||
{
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
var authorizationResult = await _authorizationService.AuthorizeAsync(User, orgId, new ManageUsersRequirement());
|
||||
if (!authorizationResult.Succeeded)
|
||||
{
|
||||
throw new UnauthorizedAccessException();
|
||||
}
|
||||
|
||||
var sponsorships = await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(orgId);
|
||||
return new ListResponseModel<OrganizationSponsorshipInvitesResponseModel>(
|
||||
sponsorships
|
||||
.Where(s => s.IsAdminInitiated)
|
||||
.Select(s => new OrganizationSponsorshipInvitesResponseModel(new OrganizationSponsorshipData(s)))
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -12,17 +12,17 @@ namespace Bit.Api.KeyManagement.Validators;
|
||||
/// </summary>
|
||||
public class SendRotationValidator : IRotationValidator<IEnumerable<SendWithIdRequestModel>, IReadOnlyList<Send>>
|
||||
{
|
||||
private readonly ISendService _sendService;
|
||||
private readonly ISendAuthorizationService _sendAuthorizationService;
|
||||
private readonly ISendRepository _sendRepository;
|
||||
|
||||
/// <summary>
|
||||
/// Instantiates a new <see cref="SendRotationValidator"/>
|
||||
/// </summary>
|
||||
/// <param name="sendService">Enables conversion of <see cref="SendWithIdRequestModel"/> to <see cref="Send"/></param>
|
||||
/// <param name="sendAuthorizationService">Enables conversion of <see cref="SendWithIdRequestModel"/> to <see cref="Send"/></param>
|
||||
/// <param name="sendRepository">Retrieves all user <see cref="Send"/>s</param>
|
||||
public SendRotationValidator(ISendService sendService, ISendRepository sendRepository)
|
||||
public SendRotationValidator(ISendAuthorizationService sendAuthorizationService, ISendRepository sendRepository)
|
||||
{
|
||||
_sendService = sendService;
|
||||
_sendAuthorizationService = sendAuthorizationService;
|
||||
_sendRepository = sendRepository;
|
||||
}
|
||||
|
||||
@ -44,7 +44,7 @@ public class SendRotationValidator : IRotationValidator<IEnumerable<SendWithIdRe
|
||||
throw new BadRequestException("All existing sends must be included in the rotation.");
|
||||
}
|
||||
|
||||
result.Add(send.ToSend(existing, _sendService));
|
||||
result.Add(send.ToSend(existing, _sendAuthorizationService));
|
||||
}
|
||||
|
||||
return result;
|
||||
|
@ -35,6 +35,7 @@ using Bit.Core.Services;
|
||||
using Bit.Core.Tools.ImportFeatures;
|
||||
using Bit.Core.Tools.ReportFeatures;
|
||||
using Bit.Core.Auth.Models.Api.Request;
|
||||
using Bit.Core.Tools.SendFeatures;
|
||||
|
||||
#if !OSS
|
||||
using Bit.Commercial.Core.SecretsManager;
|
||||
@ -186,6 +187,7 @@ public class Startup
|
||||
services.AddPhishingDomainServices(globalSettings);
|
||||
|
||||
services.AddBillingQueries();
|
||||
services.AddSendServices();
|
||||
|
||||
// Authorization Handlers
|
||||
services.AddAuthorizationHandlers();
|
||||
|
@ -12,6 +12,8 @@ using Bit.Core.Settings;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Tools.SendFeatures;
|
||||
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Bit.Core.Utilities;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
@ -25,8 +27,10 @@ public class SendsController : Controller
|
||||
{
|
||||
private readonly ISendRepository _sendRepository;
|
||||
private readonly IUserService _userService;
|
||||
private readonly ISendService _sendService;
|
||||
private readonly ISendAuthorizationService _sendAuthorizationService;
|
||||
private readonly ISendFileStorageService _sendFileStorageService;
|
||||
private readonly IAnonymousSendCommand _anonymousSendCommand;
|
||||
private readonly INonAnonymousSendCommand _nonAnonymousSendCommand;
|
||||
private readonly ILogger<SendsController> _logger;
|
||||
private readonly GlobalSettings _globalSettings;
|
||||
private readonly ICurrentContext _currentContext;
|
||||
@ -34,7 +38,9 @@ public class SendsController : Controller
|
||||
public SendsController(
|
||||
ISendRepository sendRepository,
|
||||
IUserService userService,
|
||||
ISendService sendService,
|
||||
ISendAuthorizationService sendAuthorizationService,
|
||||
IAnonymousSendCommand anonymousSendCommand,
|
||||
INonAnonymousSendCommand nonAnonymousSendCommand,
|
||||
ISendFileStorageService sendFileStorageService,
|
||||
ILogger<SendsController> logger,
|
||||
GlobalSettings globalSettings,
|
||||
@ -42,13 +48,16 @@ public class SendsController : Controller
|
||||
{
|
||||
_sendRepository = sendRepository;
|
||||
_userService = userService;
|
||||
_sendService = sendService;
|
||||
_sendAuthorizationService = sendAuthorizationService;
|
||||
_anonymousSendCommand = anonymousSendCommand;
|
||||
_nonAnonymousSendCommand = nonAnonymousSendCommand;
|
||||
_sendFileStorageService = sendFileStorageService;
|
||||
_logger = logger;
|
||||
_globalSettings = globalSettings;
|
||||
_currentContext = currentContext;
|
||||
}
|
||||
|
||||
#region Anonymous endpoints
|
||||
[AllowAnonymous]
|
||||
[HttpPost("access/{id}")]
|
||||
public async Task<IActionResult> Access(string id, [FromBody] SendAccessRequestModel model)
|
||||
@ -61,18 +70,19 @@ public class SendsController : Controller
|
||||
//}
|
||||
|
||||
var guid = new Guid(CoreHelpers.Base64UrlDecode(id));
|
||||
var (send, passwordRequired, passwordInvalid) =
|
||||
await _sendService.AccessAsync(guid, model.Password);
|
||||
if (passwordRequired)
|
||||
var send = await _sendRepository.GetByIdAsync(guid);
|
||||
SendAccessResult sendAuthResult =
|
||||
await _sendAuthorizationService.AccessAsync(send, model.Password);
|
||||
if (sendAuthResult.Equals(SendAccessResult.PasswordRequired))
|
||||
{
|
||||
return new UnauthorizedResult();
|
||||
}
|
||||
if (passwordInvalid)
|
||||
if (sendAuthResult.Equals(SendAccessResult.PasswordInvalid))
|
||||
{
|
||||
await Task.Delay(2000);
|
||||
throw new BadRequestException("Invalid password.");
|
||||
}
|
||||
if (send == null)
|
||||
if (sendAuthResult.Equals(SendAccessResult.Denied))
|
||||
{
|
||||
throw new NotFoundException();
|
||||
}
|
||||
@ -106,19 +116,19 @@ public class SendsController : Controller
|
||||
throw new BadRequestException("Could not locate send");
|
||||
}
|
||||
|
||||
var (url, passwordRequired, passwordInvalid) = await _sendService.GetSendFileDownloadUrlAsync(send, fileId,
|
||||
var (url, result) = await _anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId,
|
||||
model.Password);
|
||||
|
||||
if (passwordRequired)
|
||||
if (result.Equals(SendAccessResult.PasswordRequired))
|
||||
{
|
||||
return new UnauthorizedResult();
|
||||
}
|
||||
if (passwordInvalid)
|
||||
if (result.Equals(SendAccessResult.PasswordInvalid))
|
||||
{
|
||||
await Task.Delay(2000);
|
||||
throw new BadRequestException("Invalid password.");
|
||||
}
|
||||
if (send == null)
|
||||
if (result.Equals(SendAccessResult.Denied))
|
||||
{
|
||||
throw new NotFoundException();
|
||||
}
|
||||
@ -130,6 +140,45 @@ public class SendsController : Controller
|
||||
});
|
||||
}
|
||||
|
||||
[AllowAnonymous]
|
||||
[HttpPost("file/validate/azure")]
|
||||
public async Task<ObjectResult> AzureValidateFile()
|
||||
{
|
||||
return await ApiHelpers.HandleAzureEvents(Request, new Dictionary<string, Func<EventGridEvent, Task>>
|
||||
{
|
||||
{
|
||||
"Microsoft.Storage.BlobCreated", async (eventGridEvent) =>
|
||||
{
|
||||
try
|
||||
{
|
||||
var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1];
|
||||
var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName);
|
||||
var send = await _sendRepository.GetByIdAsync(new Guid(sendId));
|
||||
if (send == null)
|
||||
{
|
||||
if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService)
|
||||
{
|
||||
await azureSendFileStorageService.DeleteBlobAsync(blobName);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
await _nonAnonymousSendCommand.ConfirmFileSize(send);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
_logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Non-anonymous endpoints
|
||||
|
||||
[HttpGet("{id}")]
|
||||
public async Task<SendResponseModel> Get(string id)
|
||||
{
|
||||
@ -157,8 +206,8 @@ public class SendsController : Controller
|
||||
{
|
||||
model.ValidateCreation();
|
||||
var userId = _userService.GetProperUserId(User).Value;
|
||||
var send = model.ToSend(userId, _sendService);
|
||||
await _sendService.SaveSendAsync(send);
|
||||
var send = model.ToSend(userId, _sendAuthorizationService);
|
||||
await _nonAnonymousSendCommand.SaveSendAsync(send);
|
||||
return new SendResponseModel(send, _globalSettings);
|
||||
}
|
||||
|
||||
@ -175,15 +224,15 @@ public class SendsController : Controller
|
||||
throw new BadRequestException("Invalid content. File size hint is required.");
|
||||
}
|
||||
|
||||
if (model.FileLength.Value > SendService.MAX_FILE_SIZE)
|
||||
if (model.FileLength.Value > Constants.FileSize501mb)
|
||||
{
|
||||
throw new BadRequestException($"Max file size is {SendService.MAX_FILE_SIZE_READABLE}.");
|
||||
throw new BadRequestException($"Max file size is {SendFileSettingHelper.MAX_FILE_SIZE_READABLE}.");
|
||||
}
|
||||
|
||||
model.ValidateCreation();
|
||||
var userId = _userService.GetProperUserId(User).Value;
|
||||
var (send, data) = model.ToSend(userId, model.File.FileName, _sendService);
|
||||
var uploadUrl = await _sendService.SaveFileSendAsync(send, data, model.FileLength.Value);
|
||||
var (send, data) = model.ToSend(userId, model.File.FileName, _sendAuthorizationService);
|
||||
var uploadUrl = await _nonAnonymousSendCommand.SaveFileSendAsync(send, data, model.FileLength.Value);
|
||||
return new SendFileUploadDataResponseModel
|
||||
{
|
||||
Url = uploadUrl,
|
||||
@ -230,41 +279,7 @@ public class SendsController : Controller
|
||||
var send = await _sendRepository.GetByIdAsync(new Guid(id));
|
||||
await Request.GetFileAsync(async (stream) =>
|
||||
{
|
||||
await _sendService.UploadFileToExistingSendAsync(stream, send);
|
||||
});
|
||||
}
|
||||
|
||||
[AllowAnonymous]
|
||||
[HttpPost("file/validate/azure")]
|
||||
public async Task<ObjectResult> AzureValidateFile()
|
||||
{
|
||||
return await ApiHelpers.HandleAzureEvents(Request, new Dictionary<string, Func<EventGridEvent, Task>>
|
||||
{
|
||||
{
|
||||
"Microsoft.Storage.BlobCreated", async (eventGridEvent) =>
|
||||
{
|
||||
try
|
||||
{
|
||||
var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1];
|
||||
var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName);
|
||||
var send = await _sendRepository.GetByIdAsync(new Guid(sendId));
|
||||
if (send == null)
|
||||
{
|
||||
if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService)
|
||||
{
|
||||
await azureSendFileStorageService.DeleteBlobAsync(blobName);
|
||||
}
|
||||
return;
|
||||
}
|
||||
await _sendService.ValidateSendFile(send);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
_logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
await _nonAnonymousSendCommand.UploadFileToExistingSendAsync(stream, send);
|
||||
});
|
||||
}
|
||||
|
||||
@ -279,7 +294,7 @@ public class SendsController : Controller
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
await _sendService.SaveSendAsync(model.ToSend(send, _sendService));
|
||||
await _nonAnonymousSendCommand.SaveSendAsync(model.ToSend(send, _sendAuthorizationService));
|
||||
return new SendResponseModel(send, _globalSettings);
|
||||
}
|
||||
|
||||
@ -294,7 +309,7 @@ public class SendsController : Controller
|
||||
}
|
||||
|
||||
send.Password = null;
|
||||
await _sendService.SaveSendAsync(send);
|
||||
await _nonAnonymousSendCommand.SaveSendAsync(send);
|
||||
return new SendResponseModel(send, _globalSettings);
|
||||
}
|
||||
|
||||
@ -308,6 +323,8 @@ public class SendsController : Controller
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
await _sendService.DeleteSendAsync(send);
|
||||
await _nonAnonymousSendCommand.DeleteSendAsync(send);
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
|
@ -36,31 +36,31 @@ public class SendRequestModel
|
||||
public bool? Disabled { get; set; }
|
||||
public bool? HideEmail { get; set; }
|
||||
|
||||
public Send ToSend(Guid userId, ISendService sendService)
|
||||
public Send ToSend(Guid userId, ISendAuthorizationService sendAuthorizationService)
|
||||
{
|
||||
var send = new Send
|
||||
{
|
||||
Type = Type,
|
||||
UserId = (Guid?)userId
|
||||
};
|
||||
ToSend(send, sendService);
|
||||
ToSend(send, sendAuthorizationService);
|
||||
return send;
|
||||
}
|
||||
|
||||
public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendService sendService)
|
||||
public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendAuthorizationService sendAuthorizationService)
|
||||
{
|
||||
var send = ToSendBase(new Send
|
||||
{
|
||||
Type = Type,
|
||||
UserId = (Guid?)userId
|
||||
}, sendService);
|
||||
}, sendAuthorizationService);
|
||||
var data = new SendFileData(Name, Notes, fileName);
|
||||
return (send, data);
|
||||
}
|
||||
|
||||
public Send ToSend(Send existingSend, ISendService sendService)
|
||||
public Send ToSend(Send existingSend, ISendAuthorizationService sendAuthorizationService)
|
||||
{
|
||||
existingSend = ToSendBase(existingSend, sendService);
|
||||
existingSend = ToSendBase(existingSend, sendAuthorizationService);
|
||||
switch (existingSend.Type)
|
||||
{
|
||||
case SendType.File:
|
||||
@ -125,7 +125,7 @@ public class SendRequestModel
|
||||
}
|
||||
}
|
||||
|
||||
private Send ToSendBase(Send existingSend, ISendService sendService)
|
||||
private Send ToSendBase(Send existingSend, ISendAuthorizationService authorizationService)
|
||||
{
|
||||
existingSend.Key = Key;
|
||||
existingSend.ExpirationDate = ExpirationDate;
|
||||
@ -133,7 +133,7 @@ public class SendRequestModel
|
||||
existingSend.MaxAccessCount = MaxAccessCount;
|
||||
if (!string.IsNullOrWhiteSpace(Password))
|
||||
{
|
||||
existingSend.Password = sendService.HashPassword(Password);
|
||||
existingSend.Password = authorizationService.HashPassword(Password);
|
||||
}
|
||||
existingSend.Disabled = Disabled.GetValueOrDefault();
|
||||
existingSend.HideEmail = HideEmail.GetValueOrDefault();
|
||||
|
@ -1,11 +1,11 @@
|
||||
using Bit.Core;
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Billing.Services.Contracts;
|
||||
using Bit.Core.Billing.Tax.Services;
|
||||
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
@ -25,8 +25,7 @@ public class UpcomingInvoiceHandler(
|
||||
IStripeEventService stripeEventService,
|
||||
IStripeEventUtilityService stripeEventUtilityService,
|
||||
IUserRepository userRepository,
|
||||
IValidateSponsorshipCommand validateSponsorshipCommand,
|
||||
IAutomaticTaxFactory automaticTaxFactory)
|
||||
IValidateSponsorshipCommand validateSponsorshipCommand)
|
||||
: IUpcomingInvoiceHandler
|
||||
{
|
||||
public async Task HandleAsync(Event parsedEvent)
|
||||
@ -46,6 +45,8 @@ public class UpcomingInvoiceHandler(
|
||||
|
||||
var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata);
|
||||
|
||||
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (organizationId.HasValue)
|
||||
{
|
||||
var organization = await organizationRepository.GetByIdAsync(organizationId.Value);
|
||||
@ -55,7 +56,7 @@ public class UpcomingInvoiceHandler(
|
||||
return;
|
||||
}
|
||||
|
||||
await TryEnableAutomaticTaxAsync(subscription);
|
||||
await AlignOrganizationTaxConcernsAsync(organization, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge);
|
||||
|
||||
var plan = await pricingClient.GetPlanOrThrow(organization.PlanType);
|
||||
|
||||
@ -100,7 +101,25 @@ public class UpcomingInvoiceHandler(
|
||||
return;
|
||||
}
|
||||
|
||||
await TryEnableAutomaticTaxAsync(subscription);
|
||||
if (!subscription.AutomaticTax.Enabled && subscription.Customer.HasRecognizedTaxLocation())
|
||||
{
|
||||
try
|
||||
{
|
||||
await stripeFacade.UpdateSubscription(subscription.Id,
|
||||
new SubscriptionUpdateOptions
|
||||
{
|
||||
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
|
||||
});
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
logger.LogError(
|
||||
exception,
|
||||
"Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}",
|
||||
user.Id,
|
||||
parsedEvent.Id);
|
||||
}
|
||||
}
|
||||
|
||||
if (user.Premium)
|
||||
{
|
||||
@ -116,7 +135,7 @@ public class UpcomingInvoiceHandler(
|
||||
return;
|
||||
}
|
||||
|
||||
await TryEnableAutomaticTaxAsync(subscription);
|
||||
await AlignProviderTaxConcernsAsync(provider, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge);
|
||||
|
||||
await SendUpcomingInvoiceEmailsAsync(new List<string> { provider.BillingEmail }, invoice);
|
||||
}
|
||||
@ -139,50 +158,123 @@ public class UpcomingInvoiceHandler(
|
||||
}
|
||||
}
|
||||
|
||||
private async Task TryEnableAutomaticTaxAsync(Subscription subscription)
|
||||
private async Task AlignOrganizationTaxConcernsAsync(
|
||||
Organization organization,
|
||||
Subscription subscription,
|
||||
string eventId,
|
||||
bool setNonUSBusinessUseToReverseCharge)
|
||||
{
|
||||
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
|
||||
{
|
||||
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id));
|
||||
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
|
||||
var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
|
||||
var nonUSBusinessUse =
|
||||
organization.PlanType.GetProductTier() != ProductTierType.Families &&
|
||||
subscription.Customer.Address.Country != "US";
|
||||
|
||||
if (updateOptions == null)
|
||||
bool setAutomaticTaxToEnabled;
|
||||
|
||||
if (setNonUSBusinessUseToReverseCharge)
|
||||
{
|
||||
if (nonUSBusinessUse && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
|
||||
{
|
||||
return;
|
||||
try
|
||||
{
|
||||
await stripeFacade.UpdateCustomer(subscription.CustomerId,
|
||||
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
logger.LogError(
|
||||
exception,
|
||||
"Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}",
|
||||
organization.Id,
|
||||
eventId);
|
||||
}
|
||||
}
|
||||
|
||||
await stripeFacade.UpdateSubscription(subscription.Id, updateOptions);
|
||||
return;
|
||||
setAutomaticTaxToEnabled = true;
|
||||
}
|
||||
|
||||
if (subscription.AutomaticTax.Enabled ||
|
||||
!subscription.Customer.HasBillingLocation() ||
|
||||
await IsNonTaxableNonUSBusinessUseSubscription(subscription))
|
||||
else
|
||||
{
|
||||
return;
|
||||
setAutomaticTaxToEnabled =
|
||||
subscription.Customer.HasRecognizedTaxLocation() &&
|
||||
(subscription.Customer.Address.Country == "US" ||
|
||||
(nonUSBusinessUse && subscription.Customer.TaxIds.Any()));
|
||||
}
|
||||
|
||||
await stripeFacade.UpdateSubscription(subscription.Id,
|
||||
new SubscriptionUpdateOptions
|
||||
if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled)
|
||||
{
|
||||
try
|
||||
{
|
||||
DefaultTaxRates = [],
|
||||
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
|
||||
});
|
||||
await stripeFacade.UpdateSubscription(subscription.Id,
|
||||
new SubscriptionUpdateOptions
|
||||
{
|
||||
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
|
||||
});
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
logger.LogError(
|
||||
exception,
|
||||
"Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}",
|
||||
organization.Id,
|
||||
eventId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
private async Task AlignProviderTaxConcernsAsync(
|
||||
Provider provider,
|
||||
Subscription subscription,
|
||||
string eventId,
|
||||
bool setNonUSBusinessUseToReverseCharge)
|
||||
{
|
||||
bool setAutomaticTaxToEnabled;
|
||||
|
||||
async Task<bool> IsNonTaxableNonUSBusinessUseSubscription(Subscription localSubscription)
|
||||
if (setNonUSBusinessUseToReverseCharge)
|
||||
{
|
||||
var familyPriceIds = (await Task.WhenAll(
|
||||
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019),
|
||||
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)))
|
||||
.Select(plan => plan.PasswordManager.StripePlanId);
|
||||
if (subscription.Customer.Address.Country != "US" && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
|
||||
{
|
||||
try
|
||||
{
|
||||
await stripeFacade.UpdateCustomer(subscription.CustomerId,
|
||||
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
logger.LogError(
|
||||
exception,
|
||||
"Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}",
|
||||
provider.Id,
|
||||
eventId);
|
||||
}
|
||||
}
|
||||
|
||||
return localSubscription.Customer.Address.Country != "US" &&
|
||||
localSubscription.Metadata.ContainsKey(StripeConstants.MetadataKeys.OrganizationId) &&
|
||||
!localSubscription.Items.Select(item => item.Price.Id).Intersect(familyPriceIds).Any() &&
|
||||
!localSubscription.Customer.TaxIds.Any();
|
||||
setAutomaticTaxToEnabled = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
setAutomaticTaxToEnabled =
|
||||
subscription.Customer.HasRecognizedTaxLocation() &&
|
||||
(subscription.Customer.Address.Country == "US" ||
|
||||
subscription.Customer.TaxIds.Any());
|
||||
}
|
||||
|
||||
if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled)
|
||||
{
|
||||
try
|
||||
{
|
||||
await stripeFacade.UpdateSubscription(subscription.Id,
|
||||
new SubscriptionUpdateOptions
|
||||
{
|
||||
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
|
||||
});
|
||||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
logger.LogError(
|
||||
exception,
|
||||
"Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}",
|
||||
provider.Id,
|
||||
eventId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -325,5 +325,6 @@ public class Organization : ITableObject<Guid>, IStorableSubscriber, IRevisable,
|
||||
SmServiceAccounts = license.SmServiceAccounts;
|
||||
UseRiskInsights = license.UseRiskInsights;
|
||||
UseOrganizationDomains = license.UseOrganizationDomains;
|
||||
UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies;
|
||||
}
|
||||
}
|
||||
|
@ -150,6 +150,7 @@ public class SelfHostedOrganizationDetails : Organization
|
||||
AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems,
|
||||
Status = Status,
|
||||
UseRiskInsights = UseRiskInsights,
|
||||
UseAdminSponsoredFamilies = UseAdminSponsoredFamilies,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -24,9 +24,7 @@ public class GetOrganizationUsersClaimedStatusQuery : IGetOrganizationUsersClaim
|
||||
// Users can only be claimed by an Organization that is enabled and can have organization domains
|
||||
var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId);
|
||||
|
||||
// TODO: Replace "UseSso" with a new organization ability like "UseOrganizationDomains" (PM-11622).
|
||||
// Verified domains were tied to SSO, so we currently check the "UseSso" organization ability.
|
||||
if (organizationAbility is { Enabled: true, UseSso: true })
|
||||
if (organizationAbility is { Enabled: true, UseOrganizationDomains: true })
|
||||
{
|
||||
// Get all organization users with claimed domains by the organization
|
||||
var organizationUsersWithClaimedDomain = await _organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organizationId);
|
||||
|
@ -0,0 +1,187 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Models.Business;
|
||||
using Bit.Core.Models.StaticStore;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Business;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Bit.Core.Utilities;
|
||||
|
||||
namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
|
||||
|
||||
public record ProviderClientOrganizationSignUpResponse(
|
||||
Organization Organization,
|
||||
Collection DefaultCollection);
|
||||
|
||||
public interface IProviderClientOrganizationSignUpCommand
|
||||
{
|
||||
/// <summary>
|
||||
/// Sign up a new client organization for a provider.
|
||||
/// </summary>
|
||||
/// <param name="signup">The signup information.</param>
|
||||
/// <returns>A tuple containing the new organization and its default collection.</returns>
|
||||
Task<ProviderClientOrganizationSignUpResponse> SignUpClientOrganizationAsync(OrganizationSignup signup);
|
||||
}
|
||||
|
||||
public class ProviderClientOrganizationSignUpCommand : IProviderClientOrganizationSignUpCommand
|
||||
{
|
||||
public const string PlanNullErrorMessage = "Password Manager Plan was null.";
|
||||
public const string PlanDisabledErrorMessage = "Password Manager Plan is disabled.";
|
||||
public const string AdditionalSeatsNegativeErrorMessage = "You can't subtract Password Manager seats!";
|
||||
|
||||
private readonly ICurrentContext _currentContext;
|
||||
private readonly IPricingClient _pricingClient;
|
||||
private readonly IReferenceEventService _referenceEventService;
|
||||
private readonly IOrganizationRepository _organizationRepository;
|
||||
private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository;
|
||||
private readonly IApplicationCacheService _applicationCacheService;
|
||||
private readonly ICollectionRepository _collectionRepository;
|
||||
|
||||
public ProviderClientOrganizationSignUpCommand(
|
||||
ICurrentContext currentContext,
|
||||
IPricingClient pricingClient,
|
||||
IReferenceEventService referenceEventService,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationApiKeyRepository organizationApiKeyRepository,
|
||||
IApplicationCacheService applicationCacheService,
|
||||
ICollectionRepository collectionRepository)
|
||||
{
|
||||
_currentContext = currentContext;
|
||||
_pricingClient = pricingClient;
|
||||
_referenceEventService = referenceEventService;
|
||||
_organizationRepository = organizationRepository;
|
||||
_organizationApiKeyRepository = organizationApiKeyRepository;
|
||||
_applicationCacheService = applicationCacheService;
|
||||
_collectionRepository = collectionRepository;
|
||||
}
|
||||
|
||||
public async Task<ProviderClientOrganizationSignUpResponse> SignUpClientOrganizationAsync(OrganizationSignup signup)
|
||||
{
|
||||
var plan = await _pricingClient.GetPlanOrThrow(signup.Plan);
|
||||
|
||||
ValidatePlan(plan, signup.AdditionalSeats);
|
||||
|
||||
var organization = new Organization
|
||||
{
|
||||
// Pre-generate the org id so that we can save it with the Stripe subscription.
|
||||
Id = CoreHelpers.GenerateComb(),
|
||||
Name = signup.Name,
|
||||
BillingEmail = signup.BillingEmail,
|
||||
PlanType = plan!.Type,
|
||||
Seats = signup.AdditionalSeats,
|
||||
MaxCollections = plan.PasswordManager.MaxCollections,
|
||||
MaxStorageGb = 1,
|
||||
UsePolicies = plan.HasPolicies,
|
||||
UseSso = plan.HasSso,
|
||||
UseOrganizationDomains = plan.HasOrganizationDomains,
|
||||
UseGroups = plan.HasGroups,
|
||||
UseEvents = plan.HasEvents,
|
||||
UseDirectory = plan.HasDirectory,
|
||||
UseTotp = plan.HasTotp,
|
||||
Use2fa = plan.Has2fa,
|
||||
UseApi = plan.HasApi,
|
||||
UseResetPassword = plan.HasResetPassword,
|
||||
SelfHost = plan.HasSelfHost,
|
||||
UsersGetPremium = plan.UsersGetPremium,
|
||||
UseCustomPermissions = plan.HasCustomPermissions,
|
||||
UseScim = plan.HasScim,
|
||||
Plan = plan.Name,
|
||||
Gateway = GatewayType.Stripe,
|
||||
ReferenceData = signup.Owner.ReferenceData,
|
||||
Enabled = true,
|
||||
LicenseKey = CoreHelpers.SecureRandomString(20),
|
||||
PublicKey = signup.PublicKey,
|
||||
PrivateKey = signup.PrivateKey,
|
||||
CreationDate = DateTime.UtcNow,
|
||||
RevisionDate = DateTime.UtcNow,
|
||||
Status = OrganizationStatusType.Created,
|
||||
UsePasswordManager = true,
|
||||
// Secrets Manager not available for purchase with Consolidated Billing.
|
||||
UseSecretsManager = false,
|
||||
};
|
||||
|
||||
var returnValue = await SignUpAsync(organization, signup.CollectionName);
|
||||
|
||||
await _referenceEventService.RaiseEventAsync(
|
||||
new ReferenceEvent(ReferenceEventType.Signup, organization, _currentContext)
|
||||
{
|
||||
PlanName = plan.Name,
|
||||
PlanType = plan.Type,
|
||||
Seats = returnValue.Organization.Seats,
|
||||
SignupInitiationPath = signup.InitiationPath,
|
||||
Storage = returnValue.Organization.MaxStorageGb,
|
||||
});
|
||||
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
private static void ValidatePlan(Plan plan, int additionalSeats)
|
||||
{
|
||||
if (plan is null)
|
||||
{
|
||||
throw new BadRequestException(PlanNullErrorMessage);
|
||||
}
|
||||
|
||||
if (plan.Disabled)
|
||||
{
|
||||
throw new BadRequestException(PlanDisabledErrorMessage);
|
||||
}
|
||||
|
||||
if (additionalSeats < 0)
|
||||
{
|
||||
throw new BadRequestException(AdditionalSeatsNegativeErrorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Private helper method to create a new organization.
|
||||
/// </summary>
|
||||
private async Task<ProviderClientOrganizationSignUpResponse> SignUpAsync(
|
||||
Organization organization, string collectionName)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _organizationRepository.CreateAsync(organization);
|
||||
await _organizationApiKeyRepository.CreateAsync(new OrganizationApiKey
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
ApiKey = CoreHelpers.SecureRandomString(30),
|
||||
Type = OrganizationApiKeyType.Default,
|
||||
RevisionDate = DateTime.UtcNow,
|
||||
});
|
||||
await _applicationCacheService.UpsertOrganizationAbilityAsync(organization);
|
||||
|
||||
Collection defaultCollection = null;
|
||||
if (!string.IsNullOrWhiteSpace(collectionName))
|
||||
{
|
||||
defaultCollection = new Collection
|
||||
{
|
||||
Name = collectionName,
|
||||
OrganizationId = organization.Id,
|
||||
CreationDate = organization.CreationDate,
|
||||
RevisionDate = organization.CreationDate
|
||||
};
|
||||
|
||||
await _collectionRepository.CreateAsync(defaultCollection, null, null);
|
||||
}
|
||||
|
||||
return new ProviderClientOrganizationSignUpResponse(organization, defaultCollection);
|
||||
}
|
||||
catch
|
||||
{
|
||||
if (organization.Id != default)
|
||||
{
|
||||
await _organizationRepository.DeleteAsync(organization);
|
||||
await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id);
|
||||
}
|
||||
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
@ -11,8 +11,6 @@ namespace Bit.Core.Services;
|
||||
|
||||
public interface IOrganizationService
|
||||
{
|
||||
Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType,
|
||||
TaxInfo taxInfo);
|
||||
Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null);
|
||||
Task ReinstateSubscriptionAsync(Guid organizationId);
|
||||
Task<string> AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb);
|
||||
@ -20,9 +18,6 @@ public interface IOrganizationService
|
||||
Task AutoAddSeatsAsync(Organization organization, int seatsToAdd);
|
||||
Task<string> AdjustSeatsAsync(Guid organizationId, int seatAdjustment);
|
||||
Task VerifyBankAsync(Guid organizationId, int amount1, int amount2);
|
||||
#nullable enable
|
||||
Task<(Organization organization, OrganizationUser organizationUser, Collection defaultCollection)> SignupClientAsync(OrganizationSignup signup);
|
||||
#nullable disable
|
||||
/// <summary>
|
||||
/// Create a new organization on a self-hosted instance
|
||||
/// </summary>
|
||||
|
@ -144,27 +144,6 @@ public class OrganizationService : IOrganizationService
|
||||
_sendOrganizationInvitesCommand = sendOrganizationInvitesCommand;
|
||||
}
|
||||
|
||||
public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken,
|
||||
PaymentMethodType paymentMethodType, TaxInfo taxInfo)
|
||||
{
|
||||
var organization = await GetOrgById(organizationId);
|
||||
if (organization == null)
|
||||
{
|
||||
throw new NotFoundException();
|
||||
}
|
||||
|
||||
await _paymentService.SaveTaxInfoAsync(organization, taxInfo);
|
||||
var updated = await _paymentService.UpdatePaymentMethodAsync(
|
||||
organization,
|
||||
paymentMethodType,
|
||||
paymentToken,
|
||||
taxInfo);
|
||||
if (updated)
|
||||
{
|
||||
await ReplaceAndUpdateCacheAsync(organization);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null)
|
||||
{
|
||||
var organization = await GetOrgById(organizationId);
|
||||
@ -431,66 +410,6 @@ public class OrganizationService : IOrganizationService
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<(Organization organization, OrganizationUser organizationUser, Collection defaultCollection)> SignupClientAsync(OrganizationSignup signup)
|
||||
{
|
||||
var plan = await _pricingClient.GetPlanOrThrow(signup.Plan);
|
||||
|
||||
ValidatePlan(plan, signup.AdditionalSeats, "Password Manager");
|
||||
|
||||
var organization = new Organization
|
||||
{
|
||||
// Pre-generate the org id so that we can save it with the Stripe subscription.
|
||||
Id = CoreHelpers.GenerateComb(),
|
||||
Name = signup.Name,
|
||||
BillingEmail = signup.BillingEmail,
|
||||
PlanType = plan!.Type,
|
||||
Seats = signup.AdditionalSeats,
|
||||
MaxCollections = plan.PasswordManager.MaxCollections,
|
||||
MaxStorageGb = 1,
|
||||
UsePolicies = plan.HasPolicies,
|
||||
UseSso = plan.HasSso,
|
||||
UseOrganizationDomains = plan.HasOrganizationDomains,
|
||||
UseGroups = plan.HasGroups,
|
||||
UseEvents = plan.HasEvents,
|
||||
UseDirectory = plan.HasDirectory,
|
||||
UseTotp = plan.HasTotp,
|
||||
Use2fa = plan.Has2fa,
|
||||
UseApi = plan.HasApi,
|
||||
UseResetPassword = plan.HasResetPassword,
|
||||
SelfHost = plan.HasSelfHost,
|
||||
UsersGetPremium = plan.UsersGetPremium,
|
||||
UseCustomPermissions = plan.HasCustomPermissions,
|
||||
UseScim = plan.HasScim,
|
||||
Plan = plan.Name,
|
||||
Gateway = GatewayType.Stripe,
|
||||
ReferenceData = signup.Owner.ReferenceData,
|
||||
Enabled = true,
|
||||
LicenseKey = CoreHelpers.SecureRandomString(20),
|
||||
PublicKey = signup.PublicKey,
|
||||
PrivateKey = signup.PrivateKey,
|
||||
CreationDate = DateTime.UtcNow,
|
||||
RevisionDate = DateTime.UtcNow,
|
||||
Status = OrganizationStatusType.Created,
|
||||
UsePasswordManager = true,
|
||||
// Secrets Manager not available for purchase with Consolidated Billing.
|
||||
UseSecretsManager = false,
|
||||
};
|
||||
|
||||
var returnValue = await SignUpAsync(organization, default, signup.OwnerKey, signup.CollectionName, false);
|
||||
|
||||
await _referenceEventService.RaiseEventAsync(
|
||||
new ReferenceEvent(ReferenceEventType.Signup, organization, _currentContext)
|
||||
{
|
||||
PlanName = plan.Name,
|
||||
PlanType = plan.Type,
|
||||
Seats = returnValue.Item1.Seats,
|
||||
SignupInitiationPath = signup.InitiationPath,
|
||||
Storage = returnValue.Item1.MaxStorageGb,
|
||||
});
|
||||
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
private async Task ValidateSignUpPoliciesAsync(Guid ownerId)
|
||||
{
|
||||
var anySingleOrgPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg);
|
||||
@ -572,6 +491,7 @@ public class OrganizationService : IOrganizationService
|
||||
SmServiceAccounts = license.SmServiceAccounts,
|
||||
UseRiskInsights = license.UseRiskInsights,
|
||||
UseOrganizationDomains = license.UseOrganizationDomains,
|
||||
UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies,
|
||||
};
|
||||
|
||||
var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false);
|
||||
|
@ -2,9 +2,24 @@
|
||||
|
||||
public enum EmergencyAccessStatusType : byte
|
||||
{
|
||||
/// <summary>
|
||||
/// The user has been invited to be an emergency contact.
|
||||
/// </summary>
|
||||
Invited = 0,
|
||||
/// <summary>
|
||||
/// The invited user, "grantee", has accepted the request to be an emergency contact.
|
||||
/// </summary>
|
||||
Accepted = 1,
|
||||
/// <summary>
|
||||
/// The inviting user, "grantor", has approved the grantee's acceptance.
|
||||
/// </summary>
|
||||
Confirmed = 2,
|
||||
/// <summary>
|
||||
/// The grantee has initiated the recovery process.
|
||||
/// </summary>
|
||||
RecoveryInitiated = 3,
|
||||
/// <summary>
|
||||
/// The grantee has excercised their emergency access.
|
||||
/// </summary>
|
||||
RecoveryApproved = 4,
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ using Bit.Core.Auth.Entities;
|
||||
using Bit.Core.Auth.Enums;
|
||||
using Bit.Core.Auth.Models.Data;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Vault.Models.Data;
|
||||
|
||||
@ -20,6 +21,15 @@ public interface IEmergencyAccessService
|
||||
Task InitiateAsync(Guid id, User initiatingUser);
|
||||
Task ApproveAsync(Guid id, User approvingUser);
|
||||
Task RejectAsync(Guid id, User rejectingUser);
|
||||
/// <summary>
|
||||
/// This request is made by the Grantee user to fetch the policies <see cref="Policy"/> for the Grantor User.
|
||||
/// The Grantor User has to be the owner of the organization. <see cref="OrganizationUserType"/>
|
||||
/// If the Grantor user has OrganizationUserType.Owner then the policies for the _Grantor_ user
|
||||
/// are returned.
|
||||
/// </summary>
|
||||
/// <param name="id">EmergencyAccess.Id being acted on</param>
|
||||
/// <param name="requestingUser">User making the request, this is the Grantee</param>
|
||||
/// <returns>null if the GrantorUser is not an organization owner; A list of policies otherwise.</returns>
|
||||
Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser);
|
||||
Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser);
|
||||
Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key);
|
||||
|
@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
|
||||
using Bit.Core.AdminConsole.Repositories;
|
||||
using Bit.Core.Auth.Entities;
|
||||
using Bit.Core.Auth.Enums;
|
||||
using Bit.Core.Auth.Models;
|
||||
using Bit.Core.Auth.Models.Business.Tokenables;
|
||||
using Bit.Core.Auth.Models.Data;
|
||||
using Bit.Core.Entities;
|
||||
@ -16,7 +15,6 @@ using Bit.Core.Tokens;
|
||||
using Bit.Core.Vault.Models.Data;
|
||||
using Bit.Core.Vault.Repositories;
|
||||
using Bit.Core.Vault.Services;
|
||||
using Microsoft.AspNetCore.Identity;
|
||||
|
||||
namespace Bit.Core.Auth.Services;
|
||||
|
||||
@ -31,8 +29,6 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
private readonly IMailService _mailService;
|
||||
private readonly IUserService _userService;
|
||||
private readonly GlobalSettings _globalSettings;
|
||||
private readonly IPasswordHasher<User> _passwordHasher;
|
||||
private readonly IOrganizationService _organizationService;
|
||||
private readonly IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> _dataProtectorTokenizer;
|
||||
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
|
||||
|
||||
@ -45,9 +41,7 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
ICipherService cipherService,
|
||||
IMailService mailService,
|
||||
IUserService userService,
|
||||
IPasswordHasher<User> passwordHasher,
|
||||
GlobalSettings globalSettings,
|
||||
IOrganizationService organizationService,
|
||||
IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> dataProtectorTokenizer,
|
||||
IRemoveOrganizationUserCommand removeOrganizationUserCommand)
|
||||
{
|
||||
@ -59,9 +53,7 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
_cipherService = cipherService;
|
||||
_mailService = mailService;
|
||||
_userService = userService;
|
||||
_passwordHasher = passwordHasher;
|
||||
_globalSettings = globalSettings;
|
||||
_organizationService = organizationService;
|
||||
_dataProtectorTokenizer = dataProtectorTokenizer;
|
||||
_removeOrganizationUserCommand = removeOrganizationUserCommand;
|
||||
}
|
||||
@ -126,7 +118,12 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
throw new BadRequestException("Emergency Access not valid.");
|
||||
}
|
||||
|
||||
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email))
|
||||
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data))
|
||||
{
|
||||
throw new BadRequestException("Invalid token.");
|
||||
}
|
||||
|
||||
if (!data.IsValid(emergencyAccessId, user.Email))
|
||||
{
|
||||
throw new BadRequestException("Invalid token.");
|
||||
}
|
||||
@ -140,6 +137,8 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
throw new BadRequestException("Invitation already accepted.");
|
||||
}
|
||||
|
||||
// TODO PM-21687
|
||||
// Might not be reachable since the Tokenable.IsValid() does an email comparison
|
||||
if (string.IsNullOrWhiteSpace(emergencyAccess.Email) ||
|
||||
!emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
|
||||
{
|
||||
@ -163,6 +162,8 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId)
|
||||
{
|
||||
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
|
||||
// TODO PM-19438/PM-21687
|
||||
// Not sure why the GrantorId and the GranteeId are supposed to be the same?
|
||||
if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId))
|
||||
{
|
||||
throw new BadRequestException("Emergency Access not valid.");
|
||||
@ -171,9 +172,9 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
await _emergencyAccessRepository.DeleteAsync(emergencyAccess);
|
||||
}
|
||||
|
||||
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId)
|
||||
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAccessId, string key, Guid confirmingUserId)
|
||||
{
|
||||
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId);
|
||||
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
|
||||
if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted ||
|
||||
emergencyAccess.GrantorId != confirmingUserId)
|
||||
{
|
||||
@ -224,7 +225,6 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
public async Task InitiateAsync(Guid id, User initiatingUser)
|
||||
{
|
||||
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
|
||||
|
||||
if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id ||
|
||||
emergencyAccess.Status != EmergencyAccessStatusType.Confirmed)
|
||||
{
|
||||
@ -285,6 +285,9 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
|
||||
public async Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser)
|
||||
{
|
||||
// TODO PM-21687
|
||||
// Should we look up policies here or just verify the EmergencyAccess is correct
|
||||
// and handle policy logic else where? Should this be a query/Command?
|
||||
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
|
||||
|
||||
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
|
||||
@ -295,7 +298,9 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
|
||||
|
||||
var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id);
|
||||
var isOrganizationOwner = grantorOrganizations.Any<OrganizationUser>(organization => organization.Type == OrganizationUserType.Owner);
|
||||
var isOrganizationOwner = grantorOrganizations
|
||||
.Any(organization => organization.Type == OrganizationUserType.Owner);
|
||||
|
||||
var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null;
|
||||
|
||||
return policies;
|
||||
@ -311,7 +316,8 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
}
|
||||
|
||||
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
|
||||
|
||||
// TODO PM-21687
|
||||
// Redundant check of the EmergencyAccessType -> checked in IsValidRequest() ln 308
|
||||
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
|
||||
{
|
||||
throw new BadRequestException("You cannot takeover an account that is using Key Connector.");
|
||||
@ -336,7 +342,9 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
grantor.LastPasswordChangeDate = grantor.RevisionDate;
|
||||
grantor.Key = key;
|
||||
// Disable TwoFactor providers since they will otherwise block logins
|
||||
grantor.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>());
|
||||
grantor.SetTwoFactorProviders([]);
|
||||
// Disable New Device Verification since it will otherwise block logins
|
||||
grantor.VerifyDevices = false;
|
||||
await _userRepository.ReplaceAsync(grantor);
|
||||
|
||||
// Remove grantor from all organizations unless Owner
|
||||
@ -421,12 +429,22 @@ public class EmergencyAccessService : IEmergencyAccessService
|
||||
await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token);
|
||||
}
|
||||
|
||||
private string NameOrEmail(User user)
|
||||
private static string NameOrEmail(User user)
|
||||
{
|
||||
return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name;
|
||||
}
|
||||
|
||||
private bool IsValidRequest(EmergencyAccess availableAccess, User requestingUser, EmergencyAccessType requestedAccessType)
|
||||
|
||||
/*
|
||||
* Checks if EmergencyAccess Object is null
|
||||
* Checks the requesting user is the same as the granteeUser (So we are checking for proper grantee action)
|
||||
* Status _must_ equal RecoveryApproved (This means the grantor has invited, the grantee has accepted, and the grantor has approved so the shared key exists but hasn't been exercised yet)
|
||||
* request type must equal the type of access requested (View or Takeover)
|
||||
*/
|
||||
private static bool IsValidRequest(
|
||||
EmergencyAccess availableAccess,
|
||||
User requestingUser,
|
||||
EmergencyAccessType requestedAccessType)
|
||||
{
|
||||
return availableAccess != null &&
|
||||
availableAccess.GranteeId == requestingUser.Id &&
|
||||
|
@ -2,10 +2,6 @@
|
||||
|
||||
public static class StripeConstants
|
||||
{
|
||||
public static class Prices
|
||||
{
|
||||
public const string StoragePlanPersonal = "personal-storage-gb-annually";
|
||||
}
|
||||
public static class AutomaticTaxStatus
|
||||
{
|
||||
public const string Failed = "failed";
|
||||
@ -69,6 +65,11 @@ public static class StripeConstants
|
||||
public const string USBankAccount = "us_bank_account";
|
||||
}
|
||||
|
||||
public static class Prices
|
||||
{
|
||||
public const string StoragePlanPersonal = "personal-storage-gb-annually";
|
||||
}
|
||||
|
||||
public static class ProrationBehavior
|
||||
{
|
||||
public const string AlwaysInvoice = "always_invoice";
|
||||
@ -88,6 +89,13 @@ public static class StripeConstants
|
||||
public const string Paused = "paused";
|
||||
}
|
||||
|
||||
public static class TaxExempt
|
||||
{
|
||||
public const string Exempt = "exempt";
|
||||
public const string None = "none";
|
||||
public const string Reverse = "reverse";
|
||||
}
|
||||
|
||||
public static class ValidateTaxLocationTiming
|
||||
{
|
||||
public const string Deferred = "deferred";
|
||||
|
@ -15,12 +15,7 @@ public static class CustomerExtensions
|
||||
}
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Determines if a Stripe customer supports automatic tax
|
||||
/// </summary>
|
||||
/// <param name="customer"></param>
|
||||
/// <returns></returns>
|
||||
public static bool HasTaxLocationVerified(this Customer customer) =>
|
||||
public static bool HasRecognizedTaxLocation(this Customer customer) =>
|
||||
customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation;
|
||||
|
||||
public static decimal GetBillingBalance(this Customer customer)
|
||||
|
@ -22,7 +22,7 @@ public static class SubscriptionUpdateOptionsExtensions
|
||||
}
|
||||
|
||||
// We might only need to check the automatic tax status.
|
||||
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country))
|
||||
if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ public static class UpcomingInvoiceOptionsExtensions
|
||||
}
|
||||
|
||||
// We might only need to check the automatic tax status.
|
||||
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country))
|
||||
if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
@ -309,6 +309,7 @@ public class OrganizationMigrator(
|
||||
organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb;
|
||||
organization.UsePolicies = plan.HasPolicies;
|
||||
organization.UseSso = plan.HasSso;
|
||||
organization.UseOrganizationDomains = plan.HasOrganizationDomains;
|
||||
organization.UseGroups = plan.HasGroups;
|
||||
organization.UseEvents = plan.HasEvents;
|
||||
organization.UseDirectory = plan.HasDirectory;
|
||||
|
@ -26,6 +26,7 @@ public record Enterprise2019Plan : Plan
|
||||
Has2fa = true;
|
||||
HasApi = true;
|
||||
HasSso = true;
|
||||
HasOrganizationDomains = true;
|
||||
HasKeyConnector = true;
|
||||
HasScim = true;
|
||||
HasResetPassword = true;
|
||||
|
@ -26,6 +26,7 @@ public record Enterprise2020Plan : Plan
|
||||
Has2fa = true;
|
||||
HasApi = true;
|
||||
HasSso = true;
|
||||
HasOrganizationDomains = true;
|
||||
HasKeyConnector = true;
|
||||
HasScim = true;
|
||||
HasResetPassword = true;
|
||||
|
@ -26,6 +26,7 @@ public record EnterprisePlan : Plan
|
||||
Has2fa = true;
|
||||
HasApi = true;
|
||||
HasSso = true;
|
||||
HasOrganizationDomains = true;
|
||||
HasKeyConnector = true;
|
||||
HasScim = true;
|
||||
HasResetPassword = true;
|
||||
|
@ -26,6 +26,7 @@ public record Enterprise2023Plan : Plan
|
||||
Has2fa = true;
|
||||
HasApi = true;
|
||||
HasSso = true;
|
||||
HasOrganizationDomains = true;
|
||||
HasKeyConnector = true;
|
||||
HasScim = true;
|
||||
HasResetPassword = true;
|
||||
|
@ -26,6 +26,7 @@ public record PlanAdapter : Plan
|
||||
Has2fa = HasFeature("2fa");
|
||||
HasApi = HasFeature("api");
|
||||
HasSso = HasFeature("sso");
|
||||
HasOrganizationDomains = HasFeature("organizationDomains");
|
||||
HasKeyConnector = HasFeature("keyConnector");
|
||||
HasScim = HasFeature("scim");
|
||||
HasResetPassword = HasFeature("resetPassword");
|
||||
|
@ -1,11 +1,11 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.Billing.Caches;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Bit.Core.Billing.Models;
|
||||
using Bit.Core.Billing.Models.Sales;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Billing.Services.Contracts;
|
||||
using Bit.Core.Billing.Tax.Models;
|
||||
using Bit.Core.Billing.Tax.Services;
|
||||
using Bit.Core.Enums;
|
||||
@ -35,16 +35,15 @@ public class OrganizationBillingService(
|
||||
ISetupIntentCache setupIntentCache,
|
||||
IStripeAdapter stripeAdapter,
|
||||
ISubscriberService subscriberService,
|
||||
ITaxService taxService,
|
||||
IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService
|
||||
ITaxService taxService) : IOrganizationBillingService
|
||||
{
|
||||
public async Task Finalize(OrganizationSale sale)
|
||||
{
|
||||
var (organization, customerSetup, subscriptionSetup) = sale;
|
||||
|
||||
var customer = string.IsNullOrEmpty(organization.GatewayCustomerId) && customerSetup != null
|
||||
? await CreateCustomerAsync(organization, customerSetup)
|
||||
: await subscriberService.GetCustomerOrThrow(organization, new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
|
||||
? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType)
|
||||
: await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup);
|
||||
|
||||
var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup);
|
||||
|
||||
@ -121,7 +120,8 @@ public class OrganizationBillingService(
|
||||
subscription.CurrentPeriodEnd);
|
||||
}
|
||||
|
||||
public async Task UpdatePaymentMethod(
|
||||
public async Task
|
||||
UpdatePaymentMethod(
|
||||
Organization organization,
|
||||
TokenizedPaymentSource tokenizedPaymentSource,
|
||||
TaxInformation taxInformation)
|
||||
@ -151,8 +151,11 @@ public class OrganizationBillingService(
|
||||
|
||||
private async Task<Customer> CreateCustomerAsync(
|
||||
Organization organization,
|
||||
CustomerSetup customerSetup)
|
||||
CustomerSetup customerSetup,
|
||||
PlanType? updatedPlanType = null)
|
||||
{
|
||||
var planType = updatedPlanType ?? organization.PlanType;
|
||||
|
||||
var displayName = organization.DisplayName();
|
||||
|
||||
var customerCreateOptions = new CustomerCreateOptions
|
||||
@ -212,13 +215,24 @@ public class OrganizationBillingService(
|
||||
City = customerSetup.TaxInformation.City,
|
||||
PostalCode = customerSetup.TaxInformation.PostalCode,
|
||||
State = customerSetup.TaxInformation.State,
|
||||
Country = customerSetup.TaxInformation.Country,
|
||||
Country = customerSetup.TaxInformation.Country
|
||||
};
|
||||
|
||||
customerCreateOptions.Tax = new CustomerTaxOptions
|
||||
{
|
||||
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
|
||||
};
|
||||
|
||||
var setNonUSBusinessUseToReverseCharge =
|
||||
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (setNonUSBusinessUseToReverseCharge &&
|
||||
planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families &&
|
||||
customerSetup.TaxInformation.Country != "US")
|
||||
{
|
||||
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId))
|
||||
{
|
||||
var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country,
|
||||
@ -399,21 +413,68 @@ public class OrganizationBillingService(
|
||||
TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays
|
||||
};
|
||||
|
||||
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
|
||||
var setNonUSBusinessUseToReverseCharge =
|
||||
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (setNonUSBusinessUseToReverseCharge)
|
||||
{
|
||||
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType);
|
||||
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
|
||||
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
|
||||
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
|
||||
}
|
||||
else
|
||||
else if (customer.HasRecognizedTaxLocation())
|
||||
{
|
||||
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions();
|
||||
subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation();
|
||||
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled =
|
||||
subscriptionSetup.PlanType.GetProductTier() == ProductTierType.Families ||
|
||||
customer.Address.Country == "US" ||
|
||||
customer.TaxIds.Any()
|
||||
};
|
||||
}
|
||||
|
||||
return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
|
||||
}
|
||||
|
||||
private async Task<Customer> GetCustomerWhileEnsuringCorrectTaxExemptionAsync(
|
||||
Organization organization,
|
||||
SubscriptionSetup subscriptionSetup)
|
||||
{
|
||||
var customer = await subscriberService.GetCustomerOrThrow(organization,
|
||||
new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
|
||||
|
||||
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (!setNonUSBusinessUseToReverseCharge || subscriptionSetup.PlanType.GetProductTier() is
|
||||
not (ProductTierType.Teams or
|
||||
ProductTierType.TeamsStarter or
|
||||
ProductTierType.Enterprise))
|
||||
{
|
||||
return customer;
|
||||
}
|
||||
|
||||
List<string> expansions = ["tax", "tax_ids"];
|
||||
|
||||
customer = customer switch
|
||||
{
|
||||
{ Address.Country: not "US", TaxExempt: not StripeConstants.TaxExempt.Reverse } => await
|
||||
stripeAdapter.CustomerUpdateAsync(customer.Id,
|
||||
new CustomerUpdateOptions
|
||||
{
|
||||
Expand = expansions,
|
||||
TaxExempt = StripeConstants.TaxExempt.Reverse
|
||||
}),
|
||||
{ Address.Country: "US", TaxExempt: StripeConstants.TaxExempt.Reverse } => await
|
||||
stripeAdapter.CustomerUpdateAsync(customer.Id,
|
||||
new CustomerUpdateOptions
|
||||
{
|
||||
Expand = expansions,
|
||||
TaxExempt = StripeConstants.TaxExempt.None
|
||||
}),
|
||||
_ => customer
|
||||
};
|
||||
|
||||
return customer;
|
||||
}
|
||||
|
||||
private async Task<bool> IsEligibleForSelfHostAsync(
|
||||
Organization organization)
|
||||
{
|
||||
|
@ -3,8 +3,6 @@ using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Models;
|
||||
using Bit.Core.Billing.Models.Sales;
|
||||
using Bit.Core.Billing.Tax.Models;
|
||||
using Bit.Core.Billing.Tax.Services;
|
||||
using Bit.Core.Billing.Tax.Services.Implementations;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
@ -12,7 +10,6 @@ using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Settings;
|
||||
using Braintree;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Stripe;
|
||||
using Customer = Stripe.Customer;
|
||||
@ -24,20 +21,18 @@ using static Utilities;
|
||||
|
||||
public class PremiumUserBillingService(
|
||||
IBraintreeGateway braintreeGateway,
|
||||
IFeatureService featureService,
|
||||
IGlobalSettings globalSettings,
|
||||
ILogger<PremiumUserBillingService> logger,
|
||||
ISetupIntentCache setupIntentCache,
|
||||
IStripeAdapter stripeAdapter,
|
||||
ISubscriberService subscriberService,
|
||||
IUserRepository userRepository,
|
||||
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService
|
||||
IUserRepository userRepository) : IPremiumUserBillingService
|
||||
{
|
||||
public async Task Credit(User user, decimal amount)
|
||||
{
|
||||
var customer = await subscriberService.GetCustomer(user);
|
||||
|
||||
// Negative credit represents a balance and all Stripe denomination is in cents.
|
||||
// Negative credit represents a balance, and all Stripe denomination is in cents.
|
||||
var credit = (long)(amount * -100);
|
||||
|
||||
if (customer == null)
|
||||
@ -184,7 +179,7 @@ public class PremiumUserBillingService(
|
||||
City = customerSetup.TaxInformation.City,
|
||||
PostalCode = customerSetup.TaxInformation.PostalCode,
|
||||
State = customerSetup.TaxInformation.State,
|
||||
Country = customerSetup.TaxInformation.Country,
|
||||
Country = customerSetup.TaxInformation.Country
|
||||
},
|
||||
Description = user.Name,
|
||||
Email = user.Email,
|
||||
@ -324,6 +319,10 @@ public class PremiumUserBillingService(
|
||||
|
||||
var subscriptionCreateOptions = new SubscriptionCreateOptions
|
||||
{
|
||||
AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = true
|
||||
},
|
||||
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
|
||||
Customer = customer.Id,
|
||||
Items = subscriptionItemOptionsList,
|
||||
@ -337,18 +336,6 @@ public class PremiumUserBillingService(
|
||||
OffSession = true
|
||||
};
|
||||
|
||||
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
|
||||
{
|
||||
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
|
||||
}
|
||||
else
|
||||
{
|
||||
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported,
|
||||
};
|
||||
}
|
||||
|
||||
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
|
||||
|
||||
if (usingPayPal)
|
||||
@ -380,7 +367,7 @@ public class PremiumUserBillingService(
|
||||
City = taxInformation.City,
|
||||
PostalCode = taxInformation.PostalCode,
|
||||
State = taxInformation.State,
|
||||
Country = taxInformation.Country,
|
||||
Country = taxInformation.Country
|
||||
},
|
||||
Expand = ["tax"],
|
||||
Tax = new CustomerTaxOptions
|
||||
|
@ -1,7 +1,10 @@
|
||||
using Bit.Core.Billing.Caches;
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.Billing.Caches;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Bit.Core.Billing.Models;
|
||||
using Bit.Core.Billing.Services.Contracts;
|
||||
using Bit.Core.Billing.Tax.Models;
|
||||
using Bit.Core.Billing.Tax.Services;
|
||||
using Bit.Core.Entities;
|
||||
@ -28,8 +31,7 @@ public class SubscriberService(
|
||||
ILogger<SubscriberService> logger,
|
||||
ISetupIntentCache setupIntentCache,
|
||||
IStripeAdapter stripeAdapter,
|
||||
ITaxService taxService,
|
||||
IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService
|
||||
ITaxService taxService) : ISubscriberService
|
||||
{
|
||||
public async Task CancelSubscription(
|
||||
ISubscriber subscriber,
|
||||
@ -128,7 +130,7 @@ public class SubscriberService(
|
||||
[subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion
|
||||
},
|
||||
Email = subscriber.BillingEmailAddress(),
|
||||
PaymentMethodNonce = paymentMethodNonce,
|
||||
PaymentMethodNonce = paymentMethodNonce
|
||||
});
|
||||
|
||||
if (customerResult.IsSuccess())
|
||||
@ -482,7 +484,7 @@ public class SubscriberService(
|
||||
|
||||
var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First();
|
||||
|
||||
// Find the customer's existing setup intents that should be cancelled.
|
||||
// Find the customer's existing setup intents that should be canceled.
|
||||
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
|
||||
.Where(si =>
|
||||
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
|
||||
@ -519,7 +521,7 @@ public class SubscriberService(
|
||||
await stripeAdapter.PaymentMethodAttachAsync(token,
|
||||
new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId });
|
||||
|
||||
// Find the customer's existing setup intents that should be cancelled.
|
||||
// Find the customer's existing setup intents that should be canceled.
|
||||
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
|
||||
.Where(si =>
|
||||
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
|
||||
@ -637,7 +639,8 @@ public class SubscriberService(
|
||||
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
|
||||
taxInformation.Country,
|
||||
taxInformation.TaxId);
|
||||
throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError");
|
||||
|
||||
throw new BadRequestException("billingTaxIdTypeInferenceError");
|
||||
}
|
||||
}
|
||||
|
||||
@ -654,53 +657,84 @@ public class SubscriberService(
|
||||
logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.",
|
||||
taxInformation.TaxId,
|
||||
taxInformation.Country);
|
||||
throw new Exceptions.BadRequestException("billingInvalidTaxIdError");
|
||||
|
||||
throw new BadRequestException("billingInvalidTaxIdError");
|
||||
|
||||
default:
|
||||
logger.LogError(e,
|
||||
"Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.",
|
||||
taxInformation.TaxId,
|
||||
taxInformation.Country,
|
||||
customer.Id);
|
||||
throw new Exceptions.BadRequestException("billingTaxIdCreationError");
|
||||
|
||||
throw new BadRequestException("billingTaxIdCreationError");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
|
||||
var subscription =
|
||||
customer.Subscriptions.First(subscription => subscription.Id == subscriber.GatewaySubscriptionId);
|
||||
|
||||
var isBusinessUseSubscriber = subscriber switch
|
||||
{
|
||||
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
|
||||
Organization organization => organization.PlanType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families,
|
||||
Provider => true,
|
||||
_ => false
|
||||
};
|
||||
|
||||
var setNonUSBusinessUseToReverseCharge =
|
||||
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (setNonUSBusinessUseToReverseCharge && isBusinessUseSubscriber)
|
||||
{
|
||||
switch (customer)
|
||||
{
|
||||
var subscriptionGetOptions = new SubscriptionGetOptions
|
||||
case
|
||||
{
|
||||
Expand = ["customer.tax", "customer.tax_ids"]
|
||||
};
|
||||
var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
|
||||
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id));
|
||||
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
|
||||
var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
|
||||
if (automaticTaxOptions?.AutomaticTax?.Enabled != null)
|
||||
Address.Country: not "US",
|
||||
TaxExempt: not StripeConstants.TaxExempt.Reverse
|
||||
}:
|
||||
await stripeAdapter.CustomerUpdateAsync(customer.Id,
|
||||
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
|
||||
break;
|
||||
case
|
||||
{
|
||||
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions);
|
||||
}
|
||||
Address.Country: "US",
|
||||
TaxExempt: StripeConstants.TaxExempt.Reverse
|
||||
}:
|
||||
await stripeAdapter.CustomerUpdateAsync(customer.Id,
|
||||
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.None });
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (SubscriberIsEligibleForAutomaticTax(subscriber, customer))
|
||||
|
||||
if (!subscription.AutomaticTax.Enabled)
|
||||
{
|
||||
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId,
|
||||
await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
|
||||
new SubscriptionUpdateOptions
|
||||
{
|
||||
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
|
||||
});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
var automaticTaxShouldBeEnabled = subscriber switch
|
||||
{
|
||||
User => true,
|
||||
Organization organization => organization.PlanType.GetProductTier() == ProductTierType.Families ||
|
||||
customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
|
||||
Provider => customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
|
||||
_ => false
|
||||
};
|
||||
|
||||
return;
|
||||
|
||||
bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer)
|
||||
=> !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) &&
|
||||
(localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) &&
|
||||
localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported;
|
||||
if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled)
|
||||
{
|
||||
await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
|
||||
new SubscriptionUpdateOptions
|
||||
{
|
||||
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,7 +76,7 @@ public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : I
|
||||
|
||||
private bool ShouldBeEnabled(Customer customer)
|
||||
{
|
||||
if (!customer.HasTaxLocationVerified())
|
||||
if (!customer.HasRecognizedTaxLocation())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
@ -59,6 +59,6 @@ public class PersonalUseAutomaticTaxStrategy(IFeatureService featureService) : I
|
||||
|
||||
private static bool ShouldBeEnabled(Customer customer)
|
||||
{
|
||||
return customer.HasTaxLocationVerified();
|
||||
return customer.HasRecognizedTaxLocation();
|
||||
}
|
||||
}
|
||||
|
@ -115,6 +115,7 @@ public static class FeatureFlagKeys
|
||||
public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence";
|
||||
public const string EmailVerification = "email-verification";
|
||||
public const string UnauthenticatedExtensionUIRefresh = "unauth-ui-refresh";
|
||||
public const string BrowserExtensionLoginApproval = "pm-14938-browser-extension-login-approvals";
|
||||
public const string SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor";
|
||||
public const string ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor";
|
||||
public const string RecoveryCodeLogin = "pm-17128-recovery-code-login";
|
||||
@ -142,13 +143,13 @@ public static class FeatureFlagKeys
|
||||
public const string UsePricingService = "use-pricing-service";
|
||||
public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features";
|
||||
public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method";
|
||||
public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements";
|
||||
public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates";
|
||||
public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion";
|
||||
public const string PM199566_UpdateMSPToChargeAutomatically = "pm-199566-update-msp-to-charge-automatically";
|
||||
public const string PM19956_RequireProviderPaymentMethodDuringSetup = "pm-19956-require-provider-payment-method-during-setup";
|
||||
public const string UseOrganizationWarningsService = "use-organization-warnings-service";
|
||||
public const string PM20322_AllowTrialLength0 = "pm-20322-allow-trial-length-0";
|
||||
public const string PM21092_SetNonUSBusinessUseToReverseCharge = "pm-21092-set-non-us-business-use-to-reverse-charge";
|
||||
|
||||
/* Data Insights and Reporting Team */
|
||||
public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application";
|
||||
|
@ -84,6 +84,7 @@ public class OrganizationLicense : ILicense
|
||||
SmSeats = org.SmSeats;
|
||||
SmServiceAccounts = org.SmServiceAccounts;
|
||||
UseRiskInsights = org.UseRiskInsights;
|
||||
UseOrganizationDomains = org.UseOrganizationDomains;
|
||||
|
||||
// Deprecated. Left for backwards compatibility with old license versions.
|
||||
LimitCollectionCreationDeletion = org.LimitCollectionCreation || org.LimitCollectionDeletion;
|
||||
@ -195,10 +196,10 @@ public class OrganizationLicense : ILicense
|
||||
/// <remarks>Intentionally set one version behind to allow self hosted users some time to update before
|
||||
/// getting out of date license errors
|
||||
/// </remarks>
|
||||
public const int CurrentLicenseFileVersion = 14;
|
||||
public const int CurrentLicenseFileVersion = 15;
|
||||
private bool ValidLicenseVersion
|
||||
{
|
||||
get => Version is >= 1 and <= 15;
|
||||
get => Version is >= 1 and <= 16;
|
||||
}
|
||||
|
||||
public byte[] GetDataBytes(bool forHash = false)
|
||||
@ -244,6 +245,8 @@ public class OrganizationLicense : ILicense
|
||||
(Version >= 14 || !p.Name.Equals(nameof(LimitCollectionCreationDeletion))) &&
|
||||
// AllowAdminAccessToAllCollectionItems was added in Version 15
|
||||
(Version >= 15 || !p.Name.Equals(nameof(AllowAdminAccessToAllCollectionItems))) &&
|
||||
// UseOrganizationDomains was added in Version 16
|
||||
(Version >= 16 || !p.Name.Equals(nameof(UseOrganizationDomains))) &&
|
||||
(
|
||||
!forHash ||
|
||||
(
|
||||
@ -252,7 +255,10 @@ public class OrganizationLicense : ILicense
|
||||
!p.Name.Equals(nameof(Refresh))
|
||||
)
|
||||
) &&
|
||||
!p.Name.Equals(nameof(UseRiskInsights)))
|
||||
// any new fields added need to be added here so that they're ignored
|
||||
!p.Name.Equals(nameof(UseRiskInsights)) &&
|
||||
!p.Name.Equals(nameof(UseAdminSponsoredFamilies)) &&
|
||||
!p.Name.Equals(nameof(UseOrganizationDomains)))
|
||||
.OrderBy(p => p.Name)
|
||||
.Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}")
|
||||
.Aggregate((c, n) => $"{c}|{n}");
|
||||
@ -583,6 +589,11 @@ public class OrganizationLicense : ILicense
|
||||
* validation.
|
||||
*/
|
||||
|
||||
if (valid && Version >= 16)
|
||||
{
|
||||
valid = organization.UseOrganizationDomains;
|
||||
}
|
||||
|
||||
return valid;
|
||||
}
|
||||
|
||||
|
@ -69,8 +69,11 @@ public static class OrganizationServiceCollectionExtensions
|
||||
services.AddBaseOrganizationSubscriptionCommandsQueries();
|
||||
}
|
||||
|
||||
private static IServiceCollection AddOrganizationSignUpCommands(this IServiceCollection services) =>
|
||||
private static void AddOrganizationSignUpCommands(this IServiceCollection services)
|
||||
{
|
||||
services.AddScoped<ICloudOrganizationSignUpCommand, CloudOrganizationSignUpCommand>();
|
||||
services.AddScoped<IProviderClientOrganizationSignUpCommand, ProviderClientOrganizationSignUpCommand>();
|
||||
}
|
||||
|
||||
private static void AddOrganizationDeleteCommands(this IServiceCollection services)
|
||||
{
|
||||
|
@ -4,7 +4,6 @@ using Bit.Core.Billing.Models;
|
||||
using Bit.Core.Billing.Tax.Requests;
|
||||
using Bit.Core.Billing.Tax.Responses;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Models.Business;
|
||||
using Bit.Core.Models.StaticStore;
|
||||
|
||||
@ -30,8 +29,6 @@ public interface IPaymentService
|
||||
Task<string> AdjustServiceAccountsAsync(Organization organization, Plan plan, int additionalServiceAccounts);
|
||||
Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false);
|
||||
Task ReinstateSubscriptionAsync(ISubscriber subscriber);
|
||||
Task<bool> UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType,
|
||||
string paymentToken, TaxInfo taxInfo = null);
|
||||
Task<bool> CreditAccountAsync(ISubscriber subscriber, decimal creditAmount);
|
||||
Task<BillingInfo> GetBillingAsync(ISubscriber subscriber);
|
||||
Task<BillingHistoryInfo> GetBillingHistoryAsync(ISubscriber subscriber);
|
||||
|
@ -1,13 +1,13 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.AdminConsole.Models.Business;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Extensions;
|
||||
using Bit.Core.Billing.Models;
|
||||
using Bit.Core.Billing.Models.Business;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Billing.Services;
|
||||
using Bit.Core.Billing.Services.Contracts;
|
||||
using Bit.Core.Billing.Tax.Models;
|
||||
using Bit.Core.Billing.Tax.Requests;
|
||||
using Bit.Core.Billing.Tax.Responses;
|
||||
using Bit.Core.Billing.Tax.Services;
|
||||
@ -38,7 +38,6 @@ public class StripePaymentService : IPaymentService
|
||||
private readonly IGlobalSettings _globalSettings;
|
||||
private readonly IFeatureService _featureService;
|
||||
private readonly ITaxService _taxService;
|
||||
private readonly ISubscriberService _subscriberService;
|
||||
private readonly IPricingClient _pricingClient;
|
||||
private readonly IAutomaticTaxFactory _automaticTaxFactory;
|
||||
private readonly IAutomaticTaxStrategy _personalUseTaxStrategy;
|
||||
@ -51,7 +50,6 @@ public class StripePaymentService : IPaymentService
|
||||
IGlobalSettings globalSettings,
|
||||
IFeatureService featureService,
|
||||
ITaxService taxService,
|
||||
ISubscriberService subscriberService,
|
||||
IPricingClient pricingClient,
|
||||
IAutomaticTaxFactory automaticTaxFactory,
|
||||
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy)
|
||||
@ -63,7 +61,6 @@ public class StripePaymentService : IPaymentService
|
||||
_globalSettings = globalSettings;
|
||||
_featureService = featureService;
|
||||
_taxService = taxService;
|
||||
_subscriberService = subscriberService;
|
||||
_pricingClient = pricingClient;
|
||||
_automaticTaxFactory = automaticTaxFactory;
|
||||
_personalUseTaxStrategy = personalUseTaxStrategy;
|
||||
@ -136,15 +133,68 @@ public class StripePaymentService : IPaymentService
|
||||
|
||||
if (subscriptionUpdate is CompleteSubscriptionUpdate)
|
||||
{
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
|
||||
var setNonUSBusinessUseToReverseCharge =
|
||||
_featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
|
||||
|
||||
if (setNonUSBusinessUseToReverseCharge)
|
||||
{
|
||||
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, updatedItemOptions.Select(x => x.Plan ?? x.Price));
|
||||
var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters);
|
||||
automaticTaxStrategy.SetUpdateOptions(subUpdateOptions, sub);
|
||||
if (sub.Customer is
|
||||
{
|
||||
Address.Country: not "US",
|
||||
TaxExempt: not StripeConstants.TaxExempt.Reverse
|
||||
})
|
||||
{
|
||||
await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId,
|
||||
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
|
||||
}
|
||||
|
||||
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
|
||||
}
|
||||
else
|
||||
else if (sub.Customer.HasRecognizedTaxLocation())
|
||||
{
|
||||
subUpdateOptions.EnableAutomaticTax(sub.Customer, sub);
|
||||
switch (subscriber)
|
||||
{
|
||||
case User:
|
||||
{
|
||||
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
|
||||
break;
|
||||
}
|
||||
case Organization:
|
||||
{
|
||||
if (sub.Customer.Address.Country == "US")
|
||||
{
|
||||
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
|
||||
}
|
||||
else
|
||||
{
|
||||
var familyPriceIds = (await Task.WhenAll(
|
||||
_pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019),
|
||||
_pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)))
|
||||
.Select(plan => plan.PasswordManager.StripePlanId);
|
||||
|
||||
var updateIsForPersonalUse = updatedItemOptions
|
||||
.Select(option => option.Price)
|
||||
.Intersect(familyPriceIds)
|
||||
.Any();
|
||||
|
||||
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = updateIsForPersonalUse || sub.Customer.TaxIds.Any()
|
||||
};
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
case Provider:
|
||||
{
|
||||
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
|
||||
{
|
||||
Enabled = sub.Customer.Address.Country == "US" ||
|
||||
sub.Customer.TaxIds.Any()
|
||||
};
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -202,7 +252,7 @@ public class StripePaymentService : IPaymentService
|
||||
}
|
||||
else if (!invoice.Paid)
|
||||
{
|
||||
// Pay invoice with no charge to customer this completes the invoice immediately without waiting the scheduled 1h
|
||||
// Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h
|
||||
invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId);
|
||||
paymentIntentClientSecret = null;
|
||||
}
|
||||
@ -585,309 +635,6 @@ public class StripePaymentService : IPaymentService
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<bool> UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType,
|
||||
string paymentToken, TaxInfo taxInfo = null)
|
||||
{
|
||||
if (subscriber == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(subscriber));
|
||||
}
|
||||
|
||||
if (subscriber.Gateway.HasValue && subscriber.Gateway.Value != GatewayType.Stripe)
|
||||
{
|
||||
throw new GatewayException("Switching from one payment type to another is not supported. " +
|
||||
"Contact us for assistance.");
|
||||
}
|
||||
|
||||
var createdCustomer = false;
|
||||
Braintree.Customer braintreeCustomer = null;
|
||||
string stipeCustomerSourceToken = null;
|
||||
string stipeCustomerPaymentMethodId = null;
|
||||
var stripeCustomerMetadata = new Dictionary<string, string>
|
||||
{
|
||||
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
|
||||
};
|
||||
var stripePaymentMethod = paymentMethodType is PaymentMethodType.Card or PaymentMethodType.BankAccount;
|
||||
|
||||
Customer customer = null;
|
||||
|
||||
if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
|
||||
{
|
||||
var options = new CustomerGetOptions { Expand = ["sources", "tax", "subscriptions"] };
|
||||
customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, options);
|
||||
if (customer.Metadata?.Any() ?? false)
|
||||
{
|
||||
stripeCustomerMetadata = customer.Metadata;
|
||||
}
|
||||
}
|
||||
|
||||
var hadBtCustomer = stripeCustomerMetadata.ContainsKey("btCustomerId");
|
||||
if (stripePaymentMethod)
|
||||
{
|
||||
if (paymentToken.StartsWith("pm_"))
|
||||
{
|
||||
stipeCustomerPaymentMethodId = paymentToken;
|
||||
}
|
||||
else
|
||||
{
|
||||
stipeCustomerSourceToken = paymentToken;
|
||||
}
|
||||
}
|
||||
else if (paymentMethodType == PaymentMethodType.PayPal)
|
||||
{
|
||||
if (hadBtCustomer)
|
||||
{
|
||||
var pmResult = await _btGateway.PaymentMethod.CreateAsync(new Braintree.PaymentMethodRequest
|
||||
{
|
||||
CustomerId = stripeCustomerMetadata["btCustomerId"],
|
||||
PaymentMethodNonce = paymentToken
|
||||
});
|
||||
|
||||
if (pmResult.IsSuccess())
|
||||
{
|
||||
var customerResult = await _btGateway.Customer.UpdateAsync(
|
||||
stripeCustomerMetadata["btCustomerId"], new Braintree.CustomerRequest
|
||||
{
|
||||
DefaultPaymentMethodToken = pmResult.Target.Token
|
||||
});
|
||||
|
||||
if (customerResult.IsSuccess() && customerResult.Target.PaymentMethods.Length > 0)
|
||||
{
|
||||
braintreeCustomer = customerResult.Target;
|
||||
}
|
||||
else
|
||||
{
|
||||
await _btGateway.PaymentMethod.DeleteAsync(pmResult.Target.Token);
|
||||
hadBtCustomer = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
hadBtCustomer = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!hadBtCustomer)
|
||||
{
|
||||
var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest
|
||||
{
|
||||
PaymentMethodNonce = paymentToken,
|
||||
Email = subscriber.BillingEmailAddress(),
|
||||
Id = subscriber.BraintreeCustomerIdPrefix() + subscriber.Id.ToString("N").ToLower() +
|
||||
Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false),
|
||||
CustomFields = new Dictionary<string, string>
|
||||
{
|
||||
[subscriber.BraintreeIdField()] = subscriber.Id.ToString(),
|
||||
[subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion
|
||||
}
|
||||
});
|
||||
|
||||
if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0)
|
||||
{
|
||||
throw new GatewayException("Failed to create PayPal customer record.");
|
||||
}
|
||||
|
||||
braintreeCustomer = customerResult.Target;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new GatewayException("Payment method is not supported at this time.");
|
||||
}
|
||||
|
||||
if (stripeCustomerMetadata.ContainsKey("btCustomerId"))
|
||||
{
|
||||
if (braintreeCustomer?.Id != stripeCustomerMetadata["btCustomerId"])
|
||||
{
|
||||
stripeCustomerMetadata["btCustomerId_old"] = stripeCustomerMetadata["btCustomerId"];
|
||||
}
|
||||
|
||||
stripeCustomerMetadata["btCustomerId"] = braintreeCustomer?.Id;
|
||||
}
|
||||
else if (!string.IsNullOrWhiteSpace(braintreeCustomer?.Id))
|
||||
{
|
||||
stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber))
|
||||
{
|
||||
taxInfo.TaxIdType = taxInfo.TaxIdType ??
|
||||
_taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber);
|
||||
}
|
||||
|
||||
if (customer == null)
|
||||
{
|
||||
customer = await _stripeAdapter.CustomerCreateAsync(new CustomerCreateOptions
|
||||
{
|
||||
Description = subscriber.BillingName(),
|
||||
Email = subscriber.BillingEmailAddress(),
|
||||
Metadata = stripeCustomerMetadata,
|
||||
Source = stipeCustomerSourceToken,
|
||||
PaymentMethod = stipeCustomerPaymentMethodId,
|
||||
InvoiceSettings = new CustomerInvoiceSettingsOptions
|
||||
{
|
||||
DefaultPaymentMethod = stipeCustomerPaymentMethodId,
|
||||
CustomFields =
|
||||
[
|
||||
new CustomerInvoiceSettingsCustomFieldOptions()
|
||||
{
|
||||
Name = subscriber.SubscriberType(),
|
||||
Value = subscriber.GetFormattedInvoiceName()
|
||||
}
|
||||
|
||||
]
|
||||
},
|
||||
Address = taxInfo == null ? null : new AddressOptions
|
||||
{
|
||||
Country = taxInfo.BillingAddressCountry,
|
||||
PostalCode = taxInfo.BillingAddressPostalCode,
|
||||
Line1 = taxInfo.BillingAddressLine1 ?? string.Empty,
|
||||
Line2 = taxInfo.BillingAddressLine2,
|
||||
City = taxInfo.BillingAddressCity,
|
||||
State = taxInfo.BillingAddressState
|
||||
},
|
||||
TaxIdData = string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)
|
||||
? []
|
||||
: [
|
||||
new CustomerTaxIdDataOptions
|
||||
{
|
||||
Type = taxInfo.TaxIdType,
|
||||
Value = taxInfo.TaxIdNumber
|
||||
}
|
||||
],
|
||||
Expand = ["sources", "tax", "subscriptions"],
|
||||
});
|
||||
|
||||
subscriber.Gateway = GatewayType.Stripe;
|
||||
subscriber.GatewayCustomerId = customer.Id;
|
||||
createdCustomer = true;
|
||||
}
|
||||
|
||||
if (!createdCustomer)
|
||||
{
|
||||
string defaultSourceId = null;
|
||||
string defaultPaymentMethodId = null;
|
||||
if (stripePaymentMethod)
|
||||
{
|
||||
if (!string.IsNullOrWhiteSpace(stipeCustomerSourceToken) && paymentToken.StartsWith("btok_"))
|
||||
{
|
||||
var bankAccount = await _stripeAdapter.BankAccountCreateAsync(customer.Id, new BankAccountCreateOptions
|
||||
{
|
||||
Source = paymentToken
|
||||
});
|
||||
defaultSourceId = bankAccount.Id;
|
||||
}
|
||||
else if (!string.IsNullOrWhiteSpace(stipeCustomerPaymentMethodId))
|
||||
{
|
||||
await _stripeAdapter.PaymentMethodAttachAsync(stipeCustomerPaymentMethodId,
|
||||
new PaymentMethodAttachOptions { Customer = customer.Id });
|
||||
defaultPaymentMethodId = stipeCustomerPaymentMethodId;
|
||||
}
|
||||
}
|
||||
|
||||
if (customer.Sources != null)
|
||||
{
|
||||
foreach (var source in customer.Sources.Where(s => s.Id != defaultSourceId))
|
||||
{
|
||||
if (source is BankAccount)
|
||||
{
|
||||
await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id);
|
||||
}
|
||||
else if (source is Card)
|
||||
{
|
||||
await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging(new PaymentMethodListOptions
|
||||
{
|
||||
Customer = customer.Id,
|
||||
Type = "card"
|
||||
});
|
||||
foreach (var cardMethod in cardPaymentMethods.Where(m => m.Id != defaultPaymentMethodId))
|
||||
{
|
||||
await _stripeAdapter.PaymentMethodDetachAsync(cardMethod.Id, new PaymentMethodDetachOptions());
|
||||
}
|
||||
|
||||
await _subscriberService.UpdateTaxInformation(subscriber, TaxInformation.From(taxInfo));
|
||||
|
||||
customer = await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
|
||||
{
|
||||
Metadata = stripeCustomerMetadata,
|
||||
DefaultSource = defaultSourceId,
|
||||
InvoiceSettings = new CustomerInvoiceSettingsOptions
|
||||
{
|
||||
DefaultPaymentMethod = defaultPaymentMethodId,
|
||||
CustomFields =
|
||||
[
|
||||
new CustomerInvoiceSettingsCustomFieldOptions()
|
||||
{
|
||||
Name = subscriber.SubscriberType(),
|
||||
Value = subscriber.GetFormattedInvoiceName()
|
||||
}
|
||||
]
|
||||
},
|
||||
Expand = ["tax", "subscriptions"]
|
||||
});
|
||||
}
|
||||
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
|
||||
{
|
||||
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
|
||||
{
|
||||
var subscriptionGetOptions = new SubscriptionGetOptions
|
||||
{
|
||||
Expand = ["customer.tax", "customer.tax_ids"]
|
||||
};
|
||||
var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
|
||||
|
||||
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id));
|
||||
var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters);
|
||||
var subscriptionUpdateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
|
||||
|
||||
if (subscriptionUpdateOptions != null)
|
||||
{
|
||||
_ = await _stripeAdapter.SubscriptionUpdateAsync(
|
||||
subscriber.GatewaySubscriptionId,
|
||||
subscriptionUpdateOptions);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) &&
|
||||
customer.Subscriptions.Any(sub =>
|
||||
sub.Id == subscriber.GatewaySubscriptionId &&
|
||||
!sub.AutomaticTax.Enabled) &&
|
||||
customer.HasTaxLocationVerified())
|
||||
{
|
||||
var subscriptionUpdateOptions = new SubscriptionUpdateOptions
|
||||
{
|
||||
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
|
||||
DefaultTaxRates = []
|
||||
};
|
||||
|
||||
_ = await _stripeAdapter.SubscriptionUpdateAsync(
|
||||
subscriber.GatewaySubscriptionId,
|
||||
subscriptionUpdateOptions);
|
||||
}
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
if (braintreeCustomer != null && !hadBtCustomer)
|
||||
{
|
||||
await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id);
|
||||
}
|
||||
throw;
|
||||
}
|
||||
|
||||
return createdCustomer;
|
||||
}
|
||||
|
||||
public async Task<bool> CreditAccountAsync(ISubscriber subscriber, decimal creditAmount)
|
||||
{
|
||||
Customer customer = null;
|
||||
@ -1018,7 +765,7 @@ public class StripePaymentService : IPaymentService
|
||||
var address = customer.Address;
|
||||
var taxId = customer.TaxIds?.FirstOrDefault();
|
||||
|
||||
// Line1 is required, so if missing we're using the subscriber name
|
||||
// Line1 is required, so if missing we're using the subscriber name,
|
||||
// see: https://stripe.com/docs/api/customers/create#create_customer-address-line1
|
||||
if (address != null && string.IsNullOrWhiteSpace(address.Line1))
|
||||
{
|
||||
|
@ -1341,9 +1341,7 @@ public class UserService : UserManager<User>, IUserService, IDisposable
|
||||
var organizationsWithVerifiedUserEmailDomain = await _organizationRepository.GetByVerifiedUserEmailDomainAsync(userId);
|
||||
|
||||
// Organizations must be enabled and able to have verified domains.
|
||||
// TODO: Replace "UseSso" with a new organization ability like "UseOrganizationDomains" (PM-11622).
|
||||
// Verified domains were tied to SSO, so we currently check the "UseSso" organization ability.
|
||||
return organizationsWithVerifiedUserEmailDomain.Where(organization => organization is { Enabled: true, UseSso: true });
|
||||
return organizationsWithVerifiedUserEmailDomain.Where(organization => organization is { Enabled: true, UseOrganizationDomains: true });
|
||||
}
|
||||
|
||||
/// <inheritdoc cref="IsLegacyUser(string)"/>
|
||||
|
19
src/Core/Tools/Models/Data/SendAccessResult.cs
Normal file
19
src/Core/Tools/Models/Data/SendAccessResult.cs
Normal file
@ -0,0 +1,19 @@
|
||||
using Bit.Core.Tools.Entities;
|
||||
|
||||
namespace Bit.Core.Tools.Models.Data;
|
||||
|
||||
/// <summary>
|
||||
/// This enum represents the possible results when attempting to access a <see cref="Send"/>.
|
||||
/// </summary>
|
||||
/// <member>name="Granted">Access is granted for the <see cref="Send"/>.</member>
|
||||
/// <member>name="PasswordRequired">Access is denied, but a password is required to access the <see cref="Send"/>.
|
||||
/// </member>
|
||||
/// <member>name="PasswordInvalid">Access is denied due to an invalid password.</member>
|
||||
/// <member>name="Denied">Access is denied for the <see cref="Send"/>.</member>
|
||||
public enum SendAccessResult
|
||||
{
|
||||
Granted,
|
||||
PasswordRequired,
|
||||
PasswordInvalid,
|
||||
Denied
|
||||
}
|
52
src/Core/Tools/SendFeatures/Commands/AnonymousSendCommand.cs
Normal file
52
src/Core/Tools/SendFeatures/Commands/AnonymousSendCommand.cs
Normal file
@ -0,0 +1,52 @@
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
|
||||
using Bit.Core.Tools.Services;
|
||||
|
||||
namespace Bit.Core.Tools.SendFeatures.Commands;
|
||||
|
||||
public class AnonymousSendCommand : IAnonymousSendCommand
|
||||
{
|
||||
private readonly ISendRepository _sendRepository;
|
||||
private readonly ISendFileStorageService _sendFileStorageService;
|
||||
private readonly IPushNotificationService _pushNotificationService;
|
||||
private readonly ISendAuthorizationService _sendAuthorizationService;
|
||||
|
||||
public AnonymousSendCommand(
|
||||
ISendRepository sendRepository,
|
||||
ISendFileStorageService sendFileStorageService,
|
||||
IPushNotificationService pushNotificationService,
|
||||
ISendAuthorizationService sendAuthorizationService
|
||||
)
|
||||
{
|
||||
_sendRepository = sendRepository;
|
||||
_sendFileStorageService = sendFileStorageService;
|
||||
_pushNotificationService = pushNotificationService;
|
||||
_sendAuthorizationService = sendAuthorizationService;
|
||||
}
|
||||
|
||||
// Response: Send, password required, password invalid
|
||||
public async Task<(string, SendAccessResult)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
|
||||
{
|
||||
if (send.Type != SendType.File)
|
||||
{
|
||||
throw new BadRequestException("Can only get a download URL for a file type of Send");
|
||||
}
|
||||
|
||||
var result = _sendAuthorizationService.SendCanBeAccessed(send, password);
|
||||
|
||||
if (!result.Equals(SendAccessResult.Granted))
|
||||
{
|
||||
return (null, result);
|
||||
}
|
||||
|
||||
send.AccessCount++;
|
||||
await _sendRepository.ReplaceAsync(send);
|
||||
await _pushNotificationService.PushSyncSendUpdateAsync(send);
|
||||
return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), result);
|
||||
}
|
||||
}
|
@ -0,0 +1,21 @@
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
|
||||
namespace Bit.Core.Tools.SendFeatures.Commands.Interfaces;
|
||||
|
||||
/// <summary>
|
||||
/// AnonymousSendCommand interface provides methods for managing anonymous Sends.
|
||||
/// </summary>
|
||||
public interface IAnonymousSendCommand
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the Send file download URL for a Send object.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> used to help get file download url and validate file</param>
|
||||
/// <param name="fileId">FileId get file download url</param>
|
||||
/// <param name="password">A hashed and base64-encoded password. This is compared with the send's password to authorize access.</param>
|
||||
/// <returns>Async Task object with Tuple containing the string of download url and <see cref="SendAccessResult" />
|
||||
/// to determine if the user can access send.
|
||||
/// </returns>
|
||||
Task<(string, SendAccessResult)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password);
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
|
||||
namespace Bit.Core.Tools.SendFeatures.Commands.Interfaces;
|
||||
|
||||
/// <summary>
|
||||
/// NonAnonymousSendCommand interface provides methods for managing non-anonymous Sends.
|
||||
/// </summary>
|
||||
public interface INonAnonymousSendCommand
|
||||
{
|
||||
/// <summary>
|
||||
/// Saves a <see cref="Send" /> to the database.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> that will save to database</param>
|
||||
/// <returns>Task completes as <see cref="Send" /> saves to the database</returns>
|
||||
Task SaveSendAsync(Send send);
|
||||
|
||||
/// <summary>
|
||||
/// Saves the <see cref="Send" /> and <see cref="SendFileData" /> to the database.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> that will save to the database</param>
|
||||
/// <param name="data"><see cref="SendFileData" /> that will save to file storage</param>
|
||||
/// <param name="fileLength">Length of file help with saving to file storage</param>
|
||||
/// <returns>Task object for async operations with file upload url</returns>
|
||||
Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength);
|
||||
|
||||
/// <summary>
|
||||
/// Upload a file to an existing <see cref="Send" />.
|
||||
/// </summary>
|
||||
/// <param name="stream"><see cref="Stream" /> of file to be uploaded. The <see cref="Stream" /> position
|
||||
/// will be set to 0 before uploading the file.</param>
|
||||
/// <param name="send"><see cref="Send" /> used to help with uploading file</param>
|
||||
/// <returns>Task completes after saving <see cref="Stream" /> and <see cref="Send" /> metadata to the file storage</returns>
|
||||
Task UploadFileToExistingSendAsync(Stream stream, Send send);
|
||||
|
||||
/// <summary>
|
||||
/// Deletes a <see cref="Send" /> from the database and file storage.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> is used to delete from database and file storage</param>
|
||||
/// <returns>Task completes once <see cref="Send" /> has been deleted from database and file storage.</returns>
|
||||
Task DeleteSendAsync(Send send);
|
||||
|
||||
/// <summary>
|
||||
/// Stores the confirmed file size of a send; when the file size cannot be confirmed, the send is deleted.
|
||||
/// </summary>
|
||||
/// <param name="send">The <see cref="Send" /> this command acts upon</param>
|
||||
/// <returns><see langword="true" /> when the file is confirmed, otherwise <see langword="false" /></returns>
|
||||
/// <remarks>
|
||||
/// When a file size cannot be confirmed, we assume we're working with a rogue client. The send is deleted out of
|
||||
/// an abundance of caution.
|
||||
/// </remarks>
|
||||
Task<bool> ConfirmFileSize(Send send);
|
||||
}
|
180
src/Core/Tools/SendFeatures/Commands/NonAnonymousSendCommand.cs
Normal file
180
src/Core/Tools/SendFeatures/Commands/NonAnonymousSendCommand.cs
Normal file
@ -0,0 +1,180 @@
|
||||
using System.Text.Json;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Business;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Bit.Core.Utilities;
|
||||
|
||||
namespace Bit.Core.Tools.SendFeatures.Commands;
|
||||
|
||||
public class NonAnonymousSendCommand : INonAnonymousSendCommand
|
||||
{
|
||||
private readonly ISendRepository _sendRepository;
|
||||
private readonly ISendFileStorageService _sendFileStorageService;
|
||||
private readonly IPushNotificationService _pushNotificationService;
|
||||
private readonly ISendValidationService _sendValidationService;
|
||||
private readonly IReferenceEventService _referenceEventService;
|
||||
private readonly ICurrentContext _currentContext;
|
||||
private readonly ISendCoreHelperService _sendCoreHelperService;
|
||||
|
||||
public NonAnonymousSendCommand(ISendRepository sendRepository,
|
||||
ISendFileStorageService sendFileStorageService,
|
||||
IPushNotificationService pushNotificationService,
|
||||
ISendAuthorizationService sendAuthorizationService,
|
||||
ISendValidationService sendValidationService,
|
||||
IReferenceEventService referenceEventService,
|
||||
ICurrentContext currentContext,
|
||||
ISendCoreHelperService sendCoreHelperService)
|
||||
{
|
||||
_sendRepository = sendRepository;
|
||||
_sendFileStorageService = sendFileStorageService;
|
||||
_pushNotificationService = pushNotificationService;
|
||||
_sendValidationService = sendValidationService;
|
||||
_referenceEventService = referenceEventService;
|
||||
_currentContext = currentContext;
|
||||
_sendCoreHelperService = sendCoreHelperService;
|
||||
}
|
||||
|
||||
public async Task SaveSendAsync(Send send)
|
||||
{
|
||||
// Make sure user can save Sends
|
||||
await _sendValidationService.ValidateUserCanSaveAsync(send.UserId, send);
|
||||
|
||||
if (send.Id == default(Guid))
|
||||
{
|
||||
await _sendRepository.CreateAsync(send);
|
||||
await _pushNotificationService.PushSyncSendCreateAsync(send);
|
||||
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
|
||||
{
|
||||
Id = send.UserId ?? default,
|
||||
Type = ReferenceEventType.SendCreated,
|
||||
Source = ReferenceEventSource.User,
|
||||
SendType = send.Type,
|
||||
MaxAccessCount = send.MaxAccessCount,
|
||||
HasPassword = !string.IsNullOrWhiteSpace(send.Password),
|
||||
SendHasNotes = send.Data?.Contains("Notes"),
|
||||
ClientId = _currentContext.ClientId,
|
||||
ClientVersion = _currentContext.ClientVersion
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
send.RevisionDate = DateTime.UtcNow;
|
||||
await _sendRepository.UpsertAsync(send);
|
||||
await _pushNotificationService.PushSyncSendUpdateAsync(send);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength)
|
||||
{
|
||||
if (send.Type != SendType.File)
|
||||
{
|
||||
throw new BadRequestException("Send is not of type \"file\".");
|
||||
}
|
||||
|
||||
if (fileLength < 1)
|
||||
{
|
||||
throw new BadRequestException("No file data.");
|
||||
}
|
||||
|
||||
var storageBytesRemaining = await _sendValidationService.StorageRemainingForSendAsync(send);
|
||||
|
||||
if (storageBytesRemaining < fileLength)
|
||||
{
|
||||
throw new BadRequestException("Not enough storage available.");
|
||||
}
|
||||
|
||||
var fileId = _sendCoreHelperService.SecureRandomString(32, useUpperCase: false, useSpecial: false);
|
||||
|
||||
try
|
||||
{
|
||||
data.Id = fileId;
|
||||
data.Size = fileLength;
|
||||
data.Validated = false;
|
||||
send.Data = JsonSerializer.Serialize(data,
|
||||
JsonHelpers.IgnoreWritingNull);
|
||||
await SaveSendAsync(send);
|
||||
return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId);
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Clean up since this is not transactional
|
||||
await _sendFileStorageService.DeleteFileAsync(send, fileId);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
public async Task UploadFileToExistingSendAsync(Stream stream, Send send)
|
||||
{
|
||||
if (stream.Position > 0)
|
||||
{
|
||||
stream.Position = 0;
|
||||
}
|
||||
|
||||
if (send?.Data == null)
|
||||
{
|
||||
throw new BadRequestException("Send does not have file data");
|
||||
}
|
||||
|
||||
if (send.Type != SendType.File)
|
||||
{
|
||||
throw new BadRequestException("Not a File Type Send.");
|
||||
}
|
||||
|
||||
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
|
||||
|
||||
if (data.Validated)
|
||||
{
|
||||
throw new BadRequestException("File has already been uploaded.");
|
||||
}
|
||||
|
||||
await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id);
|
||||
|
||||
if (!await ConfirmFileSize(send))
|
||||
{
|
||||
throw new BadRequestException("File received does not match expected file length.");
|
||||
}
|
||||
}
|
||||
public async Task DeleteSendAsync(Send send)
|
||||
{
|
||||
await _sendRepository.DeleteAsync(send);
|
||||
if (send.Type == Enums.SendType.File)
|
||||
{
|
||||
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
|
||||
await _sendFileStorageService.DeleteFileAsync(send, data.Id);
|
||||
}
|
||||
await _pushNotificationService.PushSyncSendDeleteAsync(send);
|
||||
}
|
||||
|
||||
public async Task<bool> ConfirmFileSize(Send send)
|
||||
{
|
||||
var fileData = JsonSerializer.Deserialize<SendFileData>(send.Data);
|
||||
|
||||
var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, SendFileSettingHelper.FILE_SIZE_LEEWAY);
|
||||
|
||||
if (!valid || realSize > SendFileSettingHelper.FILE_SIZE_LEEWAY)
|
||||
{
|
||||
// File reported differs in size from that promised. Must be a rogue client. Delete Send
|
||||
await DeleteSendAsync(send);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Update Send data if necessary
|
||||
if (realSize != fileData.Size)
|
||||
{
|
||||
fileData.Size = realSize.Value;
|
||||
}
|
||||
fileData.Validated = true;
|
||||
send.Data = JsonSerializer.Serialize(fileData,
|
||||
JsonHelpers.IgnoreWritingNull);
|
||||
await SaveSendAsync(send);
|
||||
|
||||
return valid;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
using Bit.Core.Tools.SendFeatures.Commands;
|
||||
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace Bit.Core.Tools.SendFeatures;
|
||||
|
||||
public static class SendServiceCollectionExtension
|
||||
{
|
||||
public static void AddSendServices(this IServiceCollection services)
|
||||
{
|
||||
services.AddScoped<INonAnonymousSendCommand, NonAnonymousSendCommand>();
|
||||
services.AddScoped<IAnonymousSendCommand, AnonymousSendCommand>();
|
||||
services.AddScoped<ISendAuthorizationService, SendAuthorizationService>();
|
||||
services.AddScoped<ISendValidationService, SendValidationService>();
|
||||
services.AddScoped<ISendCoreHelperService, SendCoreHelperService>();
|
||||
}
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
/// <summary>
|
||||
/// Send Authorization service is responsible for checking if a Send can be accessed.
|
||||
/// </summary>
|
||||
public interface ISendAuthorizationService
|
||||
{
|
||||
/// <summary>
|
||||
/// Checks if a <see cref="Send" /> can be accessed while updating the <see cref="Send" />, pushing a notification, and sending a reference event.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> used to determine access</param>
|
||||
/// <param name="password">A hashed and base64-encoded password. This is compared with the send's password to authorize access.</param>
|
||||
/// <returns><see cref="SendAccessResult" /> will be returned to determine if the user can access send.
|
||||
/// </returns>
|
||||
Task<SendAccessResult> AccessAsync(Send send, string password);
|
||||
SendAccessResult SendCanBeAccessed(Send send,
|
||||
string password);
|
||||
|
||||
/// <summary>
|
||||
/// Hashes the password using the password hasher.
|
||||
/// </summary>
|
||||
/// <param name="password">Password to be hashed</param>
|
||||
/// <returns>Hashed password of the password given</returns>
|
||||
string HashPassword(string password);
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
/// <summary>
|
||||
/// This interface provides helper methods for generating secure random strings. Making
|
||||
/// it easier to mock the service in unit tests.
|
||||
/// </summary>
|
||||
public interface ISendCoreHelperService
|
||||
{
|
||||
/// <summary>
|
||||
/// Securely generates a random string of the specified length.
|
||||
/// </summary>
|
||||
/// <param name="length">Desired string length to be returned</param>
|
||||
/// <param name="useUpperCase">Desired casing for the string</param>
|
||||
/// <param name="useSpecial">Determines if special characters will be used in string</param>
|
||||
/// <returns>A secure random string with the desired parameters</returns>
|
||||
string SecureRandomString(int length, bool useUpperCase, bool useSpecial);
|
||||
}
|
@ -0,0 +1,71 @@
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Tools.Entities;
|
||||
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
/// <summary>
|
||||
/// Send File Storage Service is responsible for uploading, deleting, and validating files
|
||||
/// whether they are in local storage or in cloud storage.
|
||||
/// </summary>
|
||||
public interface ISendFileStorageService
|
||||
{
|
||||
FileUploadType FileUploadType { get; }
|
||||
/// <summary>
|
||||
/// Uploads a new file to the storage.
|
||||
/// </summary>
|
||||
/// <param name="stream"><see cref="Stream" /> of the file</param>
|
||||
/// <param name="send"><see cref="Send" /> for the file</param>
|
||||
/// <param name="fileId">File id</param>
|
||||
/// <returns>Task completes once <see cref="Stream" /> and <see cref="Send" /> have been saved to the database</returns>
|
||||
Task UploadNewFileAsync(Stream stream, Send send, string fileId);
|
||||
/// <summary>
|
||||
/// Deletes a file from the storage.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> used to delete file</param>
|
||||
/// <param name="fileId">File id of file to be deleted</param>
|
||||
/// <returns>Task completes once <see cref="Send" /> has been deleted to the database</returns>
|
||||
Task DeleteFileAsync(Send send, string fileId);
|
||||
/// <summary>
|
||||
/// Deletes all files for a specific organization.
|
||||
/// </summary>
|
||||
/// <param name="organizationId"><see cref="Guid" /> used to delete all files pertaining to organization</param>
|
||||
/// <returns>Task completes after running code to delete files by organization id</returns>
|
||||
Task DeleteFilesForOrganizationAsync(Guid organizationId);
|
||||
/// <summary>
|
||||
/// Deletes all files for a specific user.
|
||||
/// </summary>
|
||||
/// <param name="userId"><see cref="Guid" /> used to delete all files pertaining to user</param>
|
||||
/// <returns>Task completes after running code to delete files by user id</returns>
|
||||
Task DeleteFilesForUserAsync(Guid userId);
|
||||
/// <summary>
|
||||
/// Gets the download URL for a file.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> used to help get download url for file</param>
|
||||
/// <param name="fileId">File id to help get download url for file</param>
|
||||
/// <returns>Download url as a string</returns>
|
||||
Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId);
|
||||
/// <summary>
|
||||
/// Gets the upload URL for a file.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> used to help get upload url for file </param>
|
||||
/// <param name="fileId">File id to help get upload url for file</param>
|
||||
/// <returns>File upload url as string</returns>
|
||||
Task<string> GetSendFileUploadUrlAsync(Send send, string fileId);
|
||||
/// <summary>
|
||||
/// Validates the file size of a file in the storage.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> used to help validate file</param>
|
||||
/// <param name="fileId">File id to identify which file to validate</param>
|
||||
/// <param name="expectedFileSize">Expected file size of the file</param>
|
||||
/// <param name="leeway">
|
||||
/// Send file size tolerance in bytes. When an uploaded file's `expectedFileSize`
|
||||
/// is outside of the leeway, the storage operation fails.
|
||||
/// </param>
|
||||
/// <throws>
|
||||
/// ❌ Fill this in with an explanation of the error thrown when `leeway` is incorrect
|
||||
/// </throws>
|
||||
/// <returns>Task object for async operations with Tuple of boolean that determines if file was valid and long that
|
||||
/// the actual file size of the file.
|
||||
/// </returns>
|
||||
Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway);
|
||||
}
|
@ -0,0 +1,35 @@
|
||||
using Bit.Core.Tools.Entities;
|
||||
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
public interface ISendValidationService
|
||||
{
|
||||
/// <summary>
|
||||
/// Validates a file can be saved by specified user.
|
||||
/// </summary>
|
||||
/// <param name="userId"><see cref="Guid" /> needed to validate file for specific user</param>
|
||||
/// <param name="send"><see cref="Send" /> needed to help validate file</param>
|
||||
/// <returns>Task completes when a conditional statement has been met it will return out of the method or
|
||||
/// throw a BadRequestException.
|
||||
/// </returns>
|
||||
Task ValidateUserCanSaveAsync(Guid? userId, Send send);
|
||||
|
||||
/// <summary>
|
||||
/// Validates a file can be saved by specified user with different policy based on feature flag
|
||||
/// </summary>
|
||||
/// <param name="userId"><see cref="Guid" /> needed to validate file for specific user</param>
|
||||
/// <param name="send"><see cref="Send" /> needed to help validate file</param>
|
||||
/// <returns>Task completes when a conditional statement has been met it will return out of the method or
|
||||
/// throw a BadRequestException.
|
||||
/// </returns>
|
||||
Task ValidateUserCanSaveAsync_vNext(Guid? userId, Send send);
|
||||
|
||||
/// <summary>
|
||||
/// Calculates the remaining storage for a Send.
|
||||
/// </summary>
|
||||
/// <param name="send"><see cref="Send" /> needed to help calculate remaining storage</param>
|
||||
/// <returns>Long with the remaining bytes for storage or will throw a BadRequestException if user cannot access
|
||||
/// file or email is not verified.
|
||||
/// </returns>
|
||||
Task<long> StorageRemainingForSendAsync(Send send);
|
||||
}
|
101
src/Core/Tools/SendFeatures/Services/SendAuthorizationService.cs
Normal file
101
src/Core/Tools/SendFeatures/Services/SendAuthorizationService.cs
Normal file
@ -0,0 +1,101 @@
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Business;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Microsoft.AspNetCore.Identity;
|
||||
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
public class SendAuthorizationService : ISendAuthorizationService
|
||||
{
|
||||
private readonly ISendRepository _sendRepository;
|
||||
private readonly IPasswordHasher<User> _passwordHasher;
|
||||
private readonly IPushNotificationService _pushNotificationService;
|
||||
private readonly IReferenceEventService _referenceEventService;
|
||||
private readonly ICurrentContext _currentContext;
|
||||
|
||||
public SendAuthorizationService(
|
||||
ISendRepository sendRepository,
|
||||
IPasswordHasher<User> passwordHasher,
|
||||
IPushNotificationService pushNotificationService,
|
||||
IReferenceEventService referenceEventService,
|
||||
ICurrentContext currentContext)
|
||||
{
|
||||
_sendRepository = sendRepository;
|
||||
_passwordHasher = passwordHasher;
|
||||
_pushNotificationService = pushNotificationService;
|
||||
_referenceEventService = referenceEventService;
|
||||
_currentContext = currentContext;
|
||||
}
|
||||
|
||||
public SendAccessResult SendCanBeAccessed(Send send,
|
||||
string password)
|
||||
{
|
||||
var now = DateTime.UtcNow;
|
||||
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
|
||||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
|
||||
send.DeletionDate < now)
|
||||
{
|
||||
return SendAccessResult.Denied;
|
||||
}
|
||||
if (!string.IsNullOrWhiteSpace(send.Password))
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(password))
|
||||
{
|
||||
return SendAccessResult.PasswordRequired;
|
||||
}
|
||||
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
|
||||
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
|
||||
{
|
||||
send.Password = HashPassword(password);
|
||||
}
|
||||
if (passwordResult == PasswordVerificationResult.Failed)
|
||||
{
|
||||
return SendAccessResult.PasswordInvalid;
|
||||
}
|
||||
}
|
||||
|
||||
return SendAccessResult.Granted;
|
||||
}
|
||||
|
||||
public async Task<SendAccessResult> AccessAsync(Send sendToBeAccessed, string password)
|
||||
{
|
||||
var accessResult = SendCanBeAccessed(sendToBeAccessed, password);
|
||||
|
||||
if (!accessResult.Equals(SendAccessResult.Granted))
|
||||
{
|
||||
return accessResult;
|
||||
}
|
||||
|
||||
if (sendToBeAccessed.Type != SendType.File)
|
||||
{
|
||||
// File sends are incremented during file download
|
||||
sendToBeAccessed.AccessCount++;
|
||||
}
|
||||
|
||||
await _sendRepository.ReplaceAsync(sendToBeAccessed);
|
||||
await _pushNotificationService.PushSyncSendUpdateAsync(sendToBeAccessed);
|
||||
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
|
||||
{
|
||||
Id = sendToBeAccessed.UserId ?? default,
|
||||
Type = ReferenceEventType.SendAccessed,
|
||||
Source = ReferenceEventSource.User,
|
||||
SendType = sendToBeAccessed.Type,
|
||||
MaxAccessCount = sendToBeAccessed.MaxAccessCount,
|
||||
HasPassword = !string.IsNullOrWhiteSpace(sendToBeAccessed.Password),
|
||||
SendHasNotes = sendToBeAccessed.Data?.Contains("Notes"),
|
||||
ClientId = _currentContext.ClientId,
|
||||
ClientVersion = _currentContext.ClientVersion
|
||||
});
|
||||
return accessResult;
|
||||
}
|
||||
|
||||
public string HashPassword(string password)
|
||||
{
|
||||
return _passwordHasher.HashPassword(new User(), password);
|
||||
}
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
using Bit.Core.Utilities;
|
||||
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
public class SendCoreHelperService : ISendCoreHelperService
|
||||
{
|
||||
public string SecureRandomString(int length, bool useUpperCase, bool useSpecial)
|
||||
{
|
||||
return CoreHelpers.SecureRandomString(length, upper: useUpperCase, special: useSpecial);
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
using Bit.Core.Tools.Entities;
|
||||
|
||||
namespace Bit.Core.Tools.SendFeatures;
|
||||
|
||||
/// <summary>
|
||||
/// SendFileSettingHelper is a static class that provides constants and helper methods (if needed) for managing file
|
||||
/// settings.
|
||||
/// </summary>
|
||||
public static class SendFileSettingHelper
|
||||
{
|
||||
/// <summary>
|
||||
/// The leeway for the file size. This is the calculated 1 megabyte of cushion when doing comparisons of file sizes
|
||||
/// within the system.
|
||||
/// </summary>
|
||||
public const long FILE_SIZE_LEEWAY = 1024L * 1024L; // 1MB
|
||||
/// <summary>
|
||||
/// The maximum file size for a file uploaded in a <see cref="Send" />. Units are calculated in bytes but
|
||||
/// represent 501 megabytes. 1 megabyte is added for cushion to account for file size.
|
||||
/// </summary>
|
||||
public const long MAX_FILE_SIZE = Constants.FileSize501mb;
|
||||
|
||||
/// <summary>
|
||||
/// String of the expected file size and to be used when needing to communicate the file size to the client/user.
|
||||
/// </summary>
|
||||
public const string MAX_FILE_SIZE_READABLE = "500 MB";
|
||||
}
|
142
src/Core/Tools/SendFeatures/Services/SendValidationService.cs
Normal file
142
src/Core/Tools/SendFeatures/Services/SendValidationService.cs
Normal file
@ -0,0 +1,142 @@
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
|
||||
using Bit.Core.AdminConsole.Services;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Settings;
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Utilities;
|
||||
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
public class SendValidationService : ISendValidationService
|
||||
{
|
||||
|
||||
private readonly IUserRepository _userRepository;
|
||||
private readonly IOrganizationRepository _organizationRepository;
|
||||
private readonly IPolicyService _policyService;
|
||||
private readonly IFeatureService _featureService;
|
||||
private readonly IUserService _userService;
|
||||
private readonly GlobalSettings _globalSettings;
|
||||
private readonly ICurrentContext _currentContext;
|
||||
private readonly IPolicyRequirementQuery _policyRequirementQuery;
|
||||
|
||||
|
||||
|
||||
public SendValidationService(
|
||||
IUserRepository userRepository,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IPolicyService policyService,
|
||||
IFeatureService featureService,
|
||||
IUserService userService,
|
||||
IPolicyRequirementQuery policyRequirementQuery,
|
||||
GlobalSettings globalSettings,
|
||||
|
||||
ICurrentContext currentContext)
|
||||
{
|
||||
_userRepository = userRepository;
|
||||
_organizationRepository = organizationRepository;
|
||||
_policyService = policyService;
|
||||
_featureService = featureService;
|
||||
_userService = userService;
|
||||
_policyRequirementQuery = policyRequirementQuery;
|
||||
_globalSettings = globalSettings;
|
||||
_currentContext = currentContext;
|
||||
}
|
||||
|
||||
public async Task ValidateUserCanSaveAsync(Guid? userId, Send send)
|
||||
{
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements))
|
||||
{
|
||||
await ValidateUserCanSaveAsync_vNext(userId, send);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
var anyDisableSendPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(userId.Value,
|
||||
PolicyType.DisableSend);
|
||||
if (anyDisableSendPolicies)
|
||||
{
|
||||
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
|
||||
}
|
||||
|
||||
if (send.HideEmail.GetValueOrDefault())
|
||||
{
|
||||
var sendOptionsPolicies = await _policyService.GetPoliciesApplicableToUserAsync(userId.Value, PolicyType.SendOptions);
|
||||
if (sendOptionsPolicies.Any(p => CoreHelpers.LoadClassFromJsonData<SendOptionsPolicyData>(p.PolicyData)?.DisableHideEmail ?? false))
|
||||
{
|
||||
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async Task ValidateUserCanSaveAsync_vNext(Guid? userId, Send send)
|
||||
{
|
||||
if (!userId.HasValue)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
var disableSendRequirement = await _policyRequirementQuery.GetAsync<DisableSendPolicyRequirement>(userId.Value);
|
||||
if (disableSendRequirement.DisableSend)
|
||||
{
|
||||
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
|
||||
}
|
||||
|
||||
var sendOptionsRequirement = await _policyRequirementQuery.GetAsync<SendOptionsPolicyRequirement>(userId.Value);
|
||||
if (sendOptionsRequirement.DisableHideEmail && send.HideEmail.GetValueOrDefault())
|
||||
{
|
||||
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<long> StorageRemainingForSendAsync(Send send)
|
||||
{
|
||||
var storageBytesRemaining = 0L;
|
||||
if (send.UserId.HasValue)
|
||||
{
|
||||
var user = await _userRepository.GetByIdAsync(send.UserId.Value);
|
||||
if (!await _userService.CanAccessPremium(user))
|
||||
{
|
||||
throw new BadRequestException("You must have premium status to use file Sends.");
|
||||
}
|
||||
|
||||
if (!user.EmailVerified)
|
||||
{
|
||||
throw new BadRequestException("You must confirm your email to use file Sends.");
|
||||
}
|
||||
|
||||
if (user.Premium)
|
||||
{
|
||||
storageBytesRemaining = user.StorageBytesRemaining();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Users that get access to file storage/premium from their organization get the default
|
||||
// 1 GB max storage.
|
||||
short limit = _globalSettings.SelfHosted ? (short)10240 : (short)1;
|
||||
storageBytesRemaining = user.StorageBytesRemaining(limit);
|
||||
}
|
||||
}
|
||||
else if (send.OrganizationId.HasValue)
|
||||
{
|
||||
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
|
||||
if (!org.MaxStorageGb.HasValue)
|
||||
{
|
||||
throw new BadRequestException("This organization cannot use file sends.");
|
||||
}
|
||||
|
||||
storageBytesRemaining = org.StorageBytesRemaining();
|
||||
}
|
||||
|
||||
return storageBytesRemaining;
|
||||
}
|
||||
}
|
@ -1,16 +0,0 @@
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
public interface ISendService
|
||||
{
|
||||
Task DeleteSendAsync(Send send);
|
||||
Task SaveSendAsync(Send send);
|
||||
Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength);
|
||||
Task UploadFileToExistingSendAsync(Stream stream, Send send);
|
||||
Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password);
|
||||
string HashPassword(string password);
|
||||
Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password);
|
||||
Task<bool> ValidateSendFile(Send send);
|
||||
}
|
@ -1,16 +0,0 @@
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Tools.Entities;
|
||||
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
public interface ISendFileStorageService
|
||||
{
|
||||
FileUploadType FileUploadType { get; }
|
||||
Task UploadNewFileAsync(Stream stream, Send send, string fileId);
|
||||
Task DeleteFileAsync(Send send, string fileId);
|
||||
Task DeleteFilesForOrganizationAsync(Guid organizationId);
|
||||
Task DeleteFilesForUserAsync(Guid userId);
|
||||
Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId);
|
||||
Task<string> GetSendFileUploadUrlAsync(Send send, string fileId);
|
||||
Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway);
|
||||
}
|
@ -1,383 +0,0 @@
|
||||
using System.Text.Json;
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
|
||||
using Bit.Core.AdminConsole.Services;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Settings;
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Business;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Utilities;
|
||||
using Microsoft.AspNetCore.Identity;
|
||||
|
||||
namespace Bit.Core.Tools.Services;
|
||||
|
||||
public class SendService : ISendService
|
||||
{
|
||||
public const long MAX_FILE_SIZE = Constants.FileSize501mb;
|
||||
public const string MAX_FILE_SIZE_READABLE = "500 MB";
|
||||
private readonly ISendRepository _sendRepository;
|
||||
private readonly IUserRepository _userRepository;
|
||||
private readonly IPolicyService _policyService;
|
||||
private readonly IUserService _userService;
|
||||
private readonly IOrganizationRepository _organizationRepository;
|
||||
private readonly ISendFileStorageService _sendFileStorageService;
|
||||
private readonly IPasswordHasher<User> _passwordHasher;
|
||||
private readonly IPushNotificationService _pushService;
|
||||
private readonly IReferenceEventService _referenceEventService;
|
||||
private readonly GlobalSettings _globalSettings;
|
||||
private readonly ICurrentContext _currentContext;
|
||||
private readonly IPolicyRequirementQuery _policyRequirementQuery;
|
||||
private readonly IFeatureService _featureService;
|
||||
|
||||
private const long _fileSizeLeeway = 1024L * 1024L; // 1MB
|
||||
|
||||
public SendService(
|
||||
ISendRepository sendRepository,
|
||||
IUserRepository userRepository,
|
||||
IUserService userService,
|
||||
IOrganizationRepository organizationRepository,
|
||||
ISendFileStorageService sendFileStorageService,
|
||||
IPasswordHasher<User> passwordHasher,
|
||||
IPushNotificationService pushService,
|
||||
IReferenceEventService referenceEventService,
|
||||
GlobalSettings globalSettings,
|
||||
IPolicyService policyService,
|
||||
ICurrentContext currentContext,
|
||||
IPolicyRequirementQuery policyRequirementQuery,
|
||||
IFeatureService featureService)
|
||||
{
|
||||
_sendRepository = sendRepository;
|
||||
_userRepository = userRepository;
|
||||
_userService = userService;
|
||||
_policyService = policyService;
|
||||
_organizationRepository = organizationRepository;
|
||||
_sendFileStorageService = sendFileStorageService;
|
||||
_passwordHasher = passwordHasher;
|
||||
_pushService = pushService;
|
||||
_referenceEventService = referenceEventService;
|
||||
_globalSettings = globalSettings;
|
||||
_currentContext = currentContext;
|
||||
_policyRequirementQuery = policyRequirementQuery;
|
||||
_featureService = featureService;
|
||||
}
|
||||
|
||||
public async Task SaveSendAsync(Send send)
|
||||
{
|
||||
// Make sure user can save Sends
|
||||
await ValidateUserCanSaveAsync(send.UserId, send);
|
||||
|
||||
if (send.Id == default(Guid))
|
||||
{
|
||||
await _sendRepository.CreateAsync(send);
|
||||
await _pushService.PushSyncSendCreateAsync(send);
|
||||
await RaiseReferenceEventAsync(send, ReferenceEventType.SendCreated);
|
||||
}
|
||||
else
|
||||
{
|
||||
send.RevisionDate = DateTime.UtcNow;
|
||||
await _sendRepository.UpsertAsync(send);
|
||||
await _pushService.PushSyncSendUpdateAsync(send);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength)
|
||||
{
|
||||
if (send.Type != SendType.File)
|
||||
{
|
||||
throw new BadRequestException("Send is not of type \"file\".");
|
||||
}
|
||||
|
||||
if (fileLength < 1)
|
||||
{
|
||||
throw new BadRequestException("No file data.");
|
||||
}
|
||||
|
||||
var storageBytesRemaining = await StorageRemainingForSendAsync(send);
|
||||
|
||||
if (storageBytesRemaining < fileLength)
|
||||
{
|
||||
throw new BadRequestException("Not enough storage available.");
|
||||
}
|
||||
|
||||
var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false);
|
||||
|
||||
try
|
||||
{
|
||||
data.Id = fileId;
|
||||
data.Size = fileLength;
|
||||
data.Validated = false;
|
||||
send.Data = JsonSerializer.Serialize(data,
|
||||
JsonHelpers.IgnoreWritingNull);
|
||||
await SaveSendAsync(send);
|
||||
return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId);
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Clean up since this is not transactional
|
||||
await _sendFileStorageService.DeleteFileAsync(send, fileId);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
public async Task UploadFileToExistingSendAsync(Stream stream, Send send)
|
||||
{
|
||||
if (send?.Data == null)
|
||||
{
|
||||
throw new BadRequestException("Send does not have file data");
|
||||
}
|
||||
|
||||
if (send.Type != SendType.File)
|
||||
{
|
||||
throw new BadRequestException("Not a File Type Send.");
|
||||
}
|
||||
|
||||
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
|
||||
|
||||
if (data.Validated)
|
||||
{
|
||||
throw new BadRequestException("File has already been uploaded.");
|
||||
}
|
||||
|
||||
await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id);
|
||||
|
||||
if (!await ValidateSendFile(send))
|
||||
{
|
||||
throw new BadRequestException("File received does not match expected file length.");
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<bool> ValidateSendFile(Send send)
|
||||
{
|
||||
var fileData = JsonSerializer.Deserialize<SendFileData>(send.Data);
|
||||
|
||||
var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway);
|
||||
|
||||
if (!valid || realSize > MAX_FILE_SIZE)
|
||||
{
|
||||
// File reported differs in size from that promised. Must be a rogue client. Delete Send
|
||||
await DeleteSendAsync(send);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Update Send data if necessary
|
||||
if (realSize != fileData.Size)
|
||||
{
|
||||
fileData.Size = realSize.Value;
|
||||
}
|
||||
fileData.Validated = true;
|
||||
send.Data = JsonSerializer.Serialize(fileData,
|
||||
JsonHelpers.IgnoreWritingNull);
|
||||
await SaveSendAsync(send);
|
||||
|
||||
return valid;
|
||||
}
|
||||
|
||||
public async Task DeleteSendAsync(Send send)
|
||||
{
|
||||
await _sendRepository.DeleteAsync(send);
|
||||
if (send.Type == Enums.SendType.File)
|
||||
{
|
||||
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
|
||||
await _sendFileStorageService.DeleteFileAsync(send, data.Id);
|
||||
}
|
||||
await _pushService.PushSyncSendDeleteAsync(send);
|
||||
}
|
||||
|
||||
public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send,
|
||||
string password)
|
||||
{
|
||||
var now = DateTime.UtcNow;
|
||||
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
|
||||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
|
||||
send.DeletionDate < now)
|
||||
{
|
||||
return (false, false, false);
|
||||
}
|
||||
if (!string.IsNullOrWhiteSpace(send.Password))
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(password))
|
||||
{
|
||||
return (false, true, false);
|
||||
}
|
||||
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
|
||||
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
|
||||
{
|
||||
send.Password = HashPassword(password);
|
||||
}
|
||||
if (passwordResult == PasswordVerificationResult.Failed)
|
||||
{
|
||||
return (false, false, true);
|
||||
}
|
||||
}
|
||||
|
||||
return (true, false, false);
|
||||
}
|
||||
|
||||
// Response: Send, password required, password invalid
|
||||
public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
|
||||
{
|
||||
if (send.Type != SendType.File)
|
||||
{
|
||||
throw new BadRequestException("Can only get a download URL for a file type of Send");
|
||||
}
|
||||
|
||||
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
|
||||
|
||||
if (!grantAccess)
|
||||
{
|
||||
return (null, passwordRequired, passwordInvalid);
|
||||
}
|
||||
|
||||
send.AccessCount++;
|
||||
await _sendRepository.ReplaceAsync(send);
|
||||
await _pushService.PushSyncSendUpdateAsync(send);
|
||||
return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false);
|
||||
}
|
||||
|
||||
// Response: Send, password required, password invalid
|
||||
public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password)
|
||||
{
|
||||
var send = await _sendRepository.GetByIdAsync(sendId);
|
||||
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
|
||||
|
||||
if (!grantAccess)
|
||||
{
|
||||
return (null, passwordRequired, passwordInvalid);
|
||||
}
|
||||
|
||||
// TODO: maybe move this to a simple ++ sproc?
|
||||
if (send.Type != SendType.File)
|
||||
{
|
||||
// File sends are incremented during file download
|
||||
send.AccessCount++;
|
||||
}
|
||||
|
||||
await _sendRepository.ReplaceAsync(send);
|
||||
await _pushService.PushSyncSendUpdateAsync(send);
|
||||
await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed);
|
||||
return (send, false, false);
|
||||
}
|
||||
|
||||
private async Task RaiseReferenceEventAsync(Send send, ReferenceEventType eventType)
|
||||
{
|
||||
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
|
||||
{
|
||||
Id = send.UserId ?? default,
|
||||
Type = eventType,
|
||||
Source = ReferenceEventSource.User,
|
||||
SendType = send.Type,
|
||||
MaxAccessCount = send.MaxAccessCount,
|
||||
HasPassword = !string.IsNullOrWhiteSpace(send.Password),
|
||||
SendHasNotes = send.Data?.Contains("Notes"),
|
||||
ClientId = _currentContext.ClientId,
|
||||
ClientVersion = _currentContext.ClientVersion
|
||||
});
|
||||
}
|
||||
|
||||
public string HashPassword(string password)
|
||||
{
|
||||
return _passwordHasher.HashPassword(new User(), password);
|
||||
}
|
||||
|
||||
private async Task ValidateUserCanSaveAsync(Guid? userId, Send send)
|
||||
{
|
||||
if (_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements))
|
||||
{
|
||||
await ValidateUserCanSaveAsync_vNext(userId, send);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
var anyDisableSendPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(userId.Value,
|
||||
PolicyType.DisableSend);
|
||||
if (anyDisableSendPolicies)
|
||||
{
|
||||
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
|
||||
}
|
||||
|
||||
if (send.HideEmail.GetValueOrDefault())
|
||||
{
|
||||
var sendOptionsPolicies = await _policyService.GetPoliciesApplicableToUserAsync(userId.Value, PolicyType.SendOptions);
|
||||
if (sendOptionsPolicies.Any(p => CoreHelpers.LoadClassFromJsonData<SendOptionsPolicyData>(p.PolicyData)?.DisableHideEmail ?? false))
|
||||
{
|
||||
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ValidateUserCanSaveAsync_vNext(Guid? userId, Send send)
|
||||
{
|
||||
if (!userId.HasValue)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
var disableSendRequirement = await _policyRequirementQuery.GetAsync<DisableSendPolicyRequirement>(userId.Value);
|
||||
if (disableSendRequirement.DisableSend)
|
||||
{
|
||||
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
|
||||
}
|
||||
|
||||
var sendOptionsRequirement = await _policyRequirementQuery.GetAsync<SendOptionsPolicyRequirement>(userId.Value);
|
||||
if (sendOptionsRequirement.DisableHideEmail && send.HideEmail.GetValueOrDefault())
|
||||
{
|
||||
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<long> StorageRemainingForSendAsync(Send send)
|
||||
{
|
||||
var storageBytesRemaining = 0L;
|
||||
if (send.UserId.HasValue)
|
||||
{
|
||||
var user = await _userRepository.GetByIdAsync(send.UserId.Value);
|
||||
if (!await _userService.CanAccessPremium(user))
|
||||
{
|
||||
throw new BadRequestException("You must have premium status to use file Sends.");
|
||||
}
|
||||
|
||||
if (!user.EmailVerified)
|
||||
{
|
||||
throw new BadRequestException("You must confirm your email to use file Sends.");
|
||||
}
|
||||
|
||||
if (user.Premium)
|
||||
{
|
||||
storageBytesRemaining = user.StorageBytesRemaining();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Users that get access to file storage/premium from their organization get the default
|
||||
// 1 GB max storage.
|
||||
storageBytesRemaining = user.StorageBytesRemaining(
|
||||
_globalSettings.SelfHosted ? (short)10240 : (short)1);
|
||||
}
|
||||
}
|
||||
else if (send.OrganizationId.HasValue)
|
||||
{
|
||||
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
|
||||
if (!org.MaxStorageGb.HasValue)
|
||||
{
|
||||
throw new BadRequestException("This organization cannot use file sends.");
|
||||
}
|
||||
|
||||
storageBytesRemaining = org.StorageBytesRemaining();
|
||||
}
|
||||
|
||||
return storageBytesRemaining;
|
||||
}
|
||||
}
|
@ -64,12 +64,6 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator<ResourceOwner
|
||||
|
||||
public async Task ValidateAsync(ResourceOwnerPasswordValidationContext context)
|
||||
{
|
||||
if (!AuthEmailHeaderIsValid(context))
|
||||
{
|
||||
context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant,
|
||||
"Auth-Email header invalid.");
|
||||
return;
|
||||
}
|
||||
|
||||
var user = await _userManager.FindByEmailAsync(context.UserName.ToLowerInvariant());
|
||||
// We want to keep this device around incase the device is new for the user
|
||||
@ -168,29 +162,4 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator<ResourceOwner
|
||||
return context.Result.Subject;
|
||||
}
|
||||
|
||||
private bool AuthEmailHeaderIsValid(ResourceOwnerPasswordValidationContext context)
|
||||
{
|
||||
if (_currentContext.HttpContext.Request.Headers.TryGetValue("Auth-Email", out var authEmailHeader))
|
||||
{
|
||||
try
|
||||
{
|
||||
var authEmailDecoded = CoreHelpers.Base64UrlDecodeString(authEmailHeader);
|
||||
if (authEmailDecoded != context.UserName)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
catch (Exception e) when (e is InvalidOperationException || e is FormatException)
|
||||
{
|
||||
// Invalid B64 encoding
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ public class UserDecryptionOptionsBuilder : IUserDecryptionOptionsBuilder
|
||||
private readonly ICurrentContext _currentContext;
|
||||
private readonly IDeviceRepository _deviceRepository;
|
||||
private readonly IOrganizationUserRepository _organizationUserRepository;
|
||||
private readonly ILoginApprovingClientTypes _loginApprovingClientTypes;
|
||||
|
||||
private UserDecryptionOptions _options = new UserDecryptionOptions();
|
||||
private User? _user;
|
||||
@ -31,12 +32,14 @@ public class UserDecryptionOptionsBuilder : IUserDecryptionOptionsBuilder
|
||||
public UserDecryptionOptionsBuilder(
|
||||
ICurrentContext currentContext,
|
||||
IDeviceRepository deviceRepository,
|
||||
IOrganizationUserRepository organizationUserRepository
|
||||
IOrganizationUserRepository organizationUserRepository,
|
||||
ILoginApprovingClientTypes loginApprovingClientTypes
|
||||
)
|
||||
{
|
||||
_currentContext = currentContext;
|
||||
_deviceRepository = deviceRepository;
|
||||
_organizationUserRepository = organizationUserRepository;
|
||||
_loginApprovingClientTypes = loginApprovingClientTypes;
|
||||
}
|
||||
|
||||
public IUserDecryptionOptionsBuilder ForUser(User user)
|
||||
@ -119,8 +122,7 @@ public class UserDecryptionOptionsBuilder : IUserDecryptionOptionsBuilder
|
||||
// Checks if the current user has any devices that are capable of approving login with device requests except for
|
||||
// their current device.
|
||||
// NOTE: this doesn't check for if the users have configured the devices to be capable of approving requests as that is a client side setting.
|
||||
hasLoginApprovingDevice = allDevices
|
||||
.Any(d => d.Identifier != _device.Identifier && LoginApprovingClientTypes.TypesThatCanApprove.Contains(DeviceTypes.ToClientType(d.Type)));
|
||||
hasLoginApprovingDevice = allDevices.Any(d => d.Identifier != _device.Identifier && _loginApprovingClientTypes.TypesThatCanApprove.Contains(DeviceTypes.ToClientType(d.Type)));
|
||||
}
|
||||
|
||||
// Determine if user has manage reset password permission as post sso logic requires it for forcing users with this permission to set a MP
|
||||
|
@ -6,5 +6,12 @@ public class RegisterFinishResponseModel : ResponseModel
|
||||
{
|
||||
public RegisterFinishResponseModel()
|
||||
: base("registerFinish")
|
||||
{ }
|
||||
{
|
||||
// We are setting this to an empty string so that old mobile clients don't break, as they reqiure a non-null value.
|
||||
// This will be cleaned up in https://bitwarden.atlassian.net/browse/PM-21720.
|
||||
CaptchaBypassToken = string.Empty;
|
||||
}
|
||||
|
||||
public string CaptchaBypassToken { get; set; }
|
||||
|
||||
}
|
||||
|
@ -1,22 +1,39 @@
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Services;
|
||||
|
||||
namespace Bit.Identity.Utilities;
|
||||
|
||||
public static class LoginApprovingClientTypes
|
||||
public interface ILoginApprovingClientTypes
|
||||
{
|
||||
private static readonly IReadOnlyCollection<ClientType> _clientTypesThatCanApprove;
|
||||
IReadOnlyCollection<ClientType> TypesThatCanApprove { get; }
|
||||
}
|
||||
|
||||
static LoginApprovingClientTypes()
|
||||
public class LoginApprovingClientTypes : ILoginApprovingClientTypes
|
||||
{
|
||||
public LoginApprovingClientTypes(
|
||||
IFeatureService featureService)
|
||||
{
|
||||
var clientTypes = new List<ClientType>
|
||||
if (featureService.IsEnabled(FeatureFlagKeys.BrowserExtensionLoginApproval))
|
||||
{
|
||||
ClientType.Desktop,
|
||||
ClientType.Mobile,
|
||||
ClientType.Web,
|
||||
ClientType.Browser,
|
||||
};
|
||||
_clientTypesThatCanApprove = clientTypes.AsReadOnly();
|
||||
TypesThatCanApprove = new List<ClientType>
|
||||
{
|
||||
ClientType.Desktop,
|
||||
ClientType.Mobile,
|
||||
ClientType.Web,
|
||||
ClientType.Browser,
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
TypesThatCanApprove = new List<ClientType>
|
||||
{
|
||||
ClientType.Desktop,
|
||||
ClientType.Mobile,
|
||||
ClientType.Web,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public static IReadOnlyCollection<ClientType> TypesThatCanApprove => _clientTypesThatCanApprove;
|
||||
public IReadOnlyCollection<ClientType> TypesThatCanApprove { get; }
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ public static class ServiceCollectionExtensions
|
||||
services.AddTransient<IUserDecryptionOptionsBuilder, UserDecryptionOptionsBuilder>();
|
||||
services.AddTransient<IDeviceValidator, DeviceValidator>();
|
||||
services.AddTransient<ITwoFactorAuthenticationValidator, TwoFactorAuthenticationValidator>();
|
||||
services.AddTransient<ILoginApprovingClientTypes, LoginApprovingClientTypes>();
|
||||
|
||||
var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalIdentity);
|
||||
var identityServerBuilder = services
|
||||
|
@ -43,6 +43,7 @@ using Bit.Core.Settings;
|
||||
using Bit.Core.Tokens;
|
||||
using Bit.Core.Tools.ImportFeatures;
|
||||
using Bit.Core.Tools.ReportFeatures;
|
||||
using Bit.Core.Tools.SendFeatures;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Bit.Core.Utilities;
|
||||
using Bit.Core.Vault;
|
||||
@ -123,7 +124,7 @@ public static class ServiceCollectionExtensions
|
||||
services.AddScoped<ISsoConfigService, SsoConfigService>();
|
||||
services.AddScoped<IAuthRequestService, AuthRequestService>();
|
||||
services.AddScoped<IDuoUniversalTokenService, DuoUniversalTokenService>();
|
||||
services.AddScoped<ISendService, SendService>();
|
||||
services.AddScoped<ISendAuthorizationService, SendAuthorizationService>();
|
||||
services.AddLoginServices();
|
||||
services.AddScoped<IOrganizationDomainService, OrganizationDomainService>();
|
||||
services.AddVaultServices();
|
||||
@ -132,6 +133,7 @@ public static class ServiceCollectionExtensions
|
||||
services.AddNotificationCenterServices();
|
||||
services.AddPlatformServices();
|
||||
services.AddImportServices();
|
||||
services.AddSendServices();
|
||||
}
|
||||
|
||||
public static void AddTokenizers(this IServiceCollection services)
|
||||
|
@ -23,11 +23,11 @@ public class SendRotationValidatorTests
|
||||
public async Task ValidateAsync_Success()
|
||||
{
|
||||
// Arrange
|
||||
var sendService = Substitute.For<ISendService>();
|
||||
var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
|
||||
var sendRepository = Substitute.For<ISendRepository>();
|
||||
|
||||
var sut = new SendRotationValidator(
|
||||
sendService,
|
||||
sendAuthorizationService,
|
||||
sendRepository
|
||||
);
|
||||
|
||||
@ -52,11 +52,11 @@ public class SendRotationValidatorTests
|
||||
public async Task ValidateAsync_SendNotReturnedFromRepository_NotIncludedInOutput()
|
||||
{
|
||||
// Arrange
|
||||
var sendService = Substitute.For<ISendService>();
|
||||
var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
|
||||
var sendRepository = Substitute.For<ISendRepository>();
|
||||
|
||||
var sut = new SendRotationValidator(
|
||||
sendService,
|
||||
sendAuthorizationService,
|
||||
sendRepository
|
||||
);
|
||||
|
||||
@ -76,11 +76,11 @@ public class SendRotationValidatorTests
|
||||
public async Task ValidateAsync_InputMissingUserSend_Throws()
|
||||
{
|
||||
// Arrange
|
||||
var sendService = Substitute.For<ISendService>();
|
||||
var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
|
||||
var sendRepository = Substitute.For<ISendRepository>();
|
||||
|
||||
var sut = new SendRotationValidator(
|
||||
sendService,
|
||||
sendAuthorizationService,
|
||||
sendRepository
|
||||
);
|
||||
|
||||
|
@ -10,7 +10,9 @@ using Bit.Core.Services;
|
||||
using Bit.Core.Settings;
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Bit.Core.Utilities;
|
||||
using Microsoft.AspNetCore.Mvc;
|
||||
@ -26,7 +28,9 @@ public class SendsControllerTests : IDisposable
|
||||
private readonly GlobalSettings _globalSettings;
|
||||
private readonly IUserService _userService;
|
||||
private readonly ISendRepository _sendRepository;
|
||||
private readonly ISendService _sendService;
|
||||
private readonly INonAnonymousSendCommand _nonAnonymousSendCommand;
|
||||
private readonly IAnonymousSendCommand _anonymousSendCommand;
|
||||
private readonly ISendAuthorizationService _sendAuthorizationService;
|
||||
private readonly ISendFileStorageService _sendFileStorageService;
|
||||
private readonly ILogger<SendsController> _logger;
|
||||
private readonly ICurrentContext _currentContext;
|
||||
@ -35,7 +39,9 @@ public class SendsControllerTests : IDisposable
|
||||
{
|
||||
_userService = Substitute.For<IUserService>();
|
||||
_sendRepository = Substitute.For<ISendRepository>();
|
||||
_sendService = Substitute.For<ISendService>();
|
||||
_nonAnonymousSendCommand = Substitute.For<INonAnonymousSendCommand>();
|
||||
_anonymousSendCommand = Substitute.For<IAnonymousSendCommand>();
|
||||
_sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
|
||||
_sendFileStorageService = Substitute.For<ISendFileStorageService>();
|
||||
_globalSettings = new GlobalSettings();
|
||||
_logger = Substitute.For<ILogger<SendsController>>();
|
||||
@ -44,7 +50,9 @@ public class SendsControllerTests : IDisposable
|
||||
_sut = new SendsController(
|
||||
_sendRepository,
|
||||
_userService,
|
||||
_sendService,
|
||||
_sendAuthorizationService,
|
||||
_anonymousSendCommand,
|
||||
_nonAnonymousSendCommand,
|
||||
_sendFileStorageService,
|
||||
_logger,
|
||||
_globalSettings,
|
||||
@ -68,7 +76,8 @@ public class SendsControllerTests : IDisposable
|
||||
send.Data = JsonSerializer.Serialize(new Dictionary<string, string>());
|
||||
send.HideEmail = true;
|
||||
|
||||
_sendService.AccessAsync(id, null).Returns((send, false, false));
|
||||
_sendRepository.GetByIdAsync(Arg.Any<Guid>()).Returns(send);
|
||||
_sendAuthorizationService.AccessAsync(send, null).Returns(SendAccessResult.Granted);
|
||||
_userService.GetUserByIdAsync(Arg.Any<Guid>()).Returns(user);
|
||||
|
||||
var request = new SendAccessRequestModel();
|
||||
|
@ -34,11 +34,11 @@ public class SendRequestModelTests
|
||||
Type = SendType.Text,
|
||||
};
|
||||
|
||||
var sendService = Substitute.For<ISendService>();
|
||||
sendService.HashPassword(Arg.Any<string>())
|
||||
var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
|
||||
sendAuthorizationService.HashPassword(Arg.Any<string>())
|
||||
.Returns((info) => $"hashed_{(string)info[0]}");
|
||||
|
||||
var send = sendRequest.ToSend(Guid.NewGuid(), sendService);
|
||||
var send = sendRequest.ToSend(Guid.NewGuid(), sendAuthorizationService);
|
||||
|
||||
Assert.Equal(deletionDate, send.DeletionDate);
|
||||
Assert.False(send.Disabled);
|
||||
|
@ -25,13 +25,13 @@ public class GetOrganizationUsersClaimedStatusQueryTests
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task GetUsersOrganizationManagementStatusAsync_WithUseSsoEnabled_Success(
|
||||
public async Task GetUsersOrganizationManagementStatusAsync_WithUseOrganizationDomainsEnabled_Success(
|
||||
Organization organization,
|
||||
ICollection<OrganizationUser> usersWithClaimedDomain,
|
||||
SutProvider<GetOrganizationUsersClaimedStatusQuery> sutProvider)
|
||||
{
|
||||
organization.Enabled = true;
|
||||
organization.UseSso = true;
|
||||
organization.UseOrganizationDomains = true;
|
||||
|
||||
var userIdWithoutClaimedDomain = Guid.NewGuid();
|
||||
var userIdsToCheck = usersWithClaimedDomain.Select(u => u.Id).Concat(new List<Guid> { userIdWithoutClaimedDomain }).ToList();
|
||||
@ -51,13 +51,13 @@ public class GetOrganizationUsersClaimedStatusQueryTests
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task GetUsersOrganizationManagementStatusAsync_WithUseSsoDisabled_ReturnsAllFalse(
|
||||
public async Task GetUsersOrganizationManagementStatusAsync_WithUseOrganizationDomainsDisabled_ReturnsAllFalse(
|
||||
Organization organization,
|
||||
ICollection<OrganizationUser> usersWithClaimedDomain,
|
||||
SutProvider<GetOrganizationUsersClaimedStatusQuery> sutProvider)
|
||||
{
|
||||
organization.Enabled = true;
|
||||
organization.UseSso = false;
|
||||
organization.UseOrganizationDomains = false;
|
||||
|
||||
var userIdWithoutClaimedDomain = Guid.NewGuid();
|
||||
var userIdsToCheck = usersWithClaimedDomain.Select(u => u.Id).Concat(new List<Guid> { userIdWithoutClaimedDomain }).ToList();
|
||||
|
@ -0,0 +1,169 @@
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Models.Business;
|
||||
using Bit.Core.Models.Data;
|
||||
using Bit.Core.Models.StaticStore;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Business;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Bit.Core.Utilities;
|
||||
using Bit.Test.Common.AutoFixture;
|
||||
using Bit.Test.Common.AutoFixture.Attributes;
|
||||
using NSubstitute;
|
||||
using Xunit;
|
||||
|
||||
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Organizations.OrganizationSignUp;
|
||||
|
||||
[SutProviderCustomize]
|
||||
public class ProviderClientOrganizationSignUpCommandTests
|
||||
{
|
||||
[Theory]
|
||||
[BitAutoData(PlanType.TeamsAnnually)]
|
||||
[BitAutoData(PlanType.TeamsMonthly)]
|
||||
[BitAutoData(PlanType.EnterpriseAnnually)]
|
||||
[BitAutoData(PlanType.EnterpriseMonthly)]
|
||||
public async Task SignupClientAsync_ValidParameters_CreatesOrganizationSuccessfully(
|
||||
PlanType planType,
|
||||
OrganizationSignup signup,
|
||||
string collectionName,
|
||||
SutProvider<ProviderClientOrganizationSignUpCommand> sutProvider)
|
||||
{
|
||||
signup.Plan = planType;
|
||||
signup.AdditionalSeats = 15;
|
||||
signup.CollectionName = collectionName;
|
||||
|
||||
var plan = StaticStore.GetPlan(signup.Plan);
|
||||
sutProvider.GetDependency<IPricingClient>()
|
||||
.GetPlanOrThrow(signup.Plan)
|
||||
.Returns(plan);
|
||||
|
||||
var result = await sutProvider.Sut.SignUpClientOrganizationAsync(signup);
|
||||
|
||||
Assert.NotNull(result.Organization);
|
||||
Assert.NotNull(result.DefaultCollection);
|
||||
Assert.Equal(collectionName, result.DefaultCollection.Name);
|
||||
|
||||
await sutProvider.GetDependency<IOrganizationRepository>()
|
||||
.Received(1)
|
||||
.CreateAsync(
|
||||
Arg.Is<Organization>(o =>
|
||||
o.Name == signup.Name &&
|
||||
o.BillingEmail == signup.BillingEmail &&
|
||||
o.PlanType == plan.Type &&
|
||||
o.Seats == signup.AdditionalSeats &&
|
||||
o.MaxCollections == plan.PasswordManager.MaxCollections &&
|
||||
o.UsePasswordManager == true &&
|
||||
o.UseSecretsManager == false &&
|
||||
o.Status == OrganizationStatusType.Created
|
||||
)
|
||||
);
|
||||
|
||||
await sutProvider.GetDependency<IReferenceEventService>()
|
||||
.Received(1)
|
||||
.RaiseEventAsync(Arg.Is<ReferenceEvent>(referenceEvent =>
|
||||
referenceEvent.Type == ReferenceEventType.Signup &&
|
||||
referenceEvent.PlanName == plan.Name &&
|
||||
referenceEvent.PlanType == plan.Type &&
|
||||
referenceEvent.Seats == result.Organization.Seats &&
|
||||
referenceEvent.Storage == result.Organization.MaxStorageGb &&
|
||||
referenceEvent.SignupInitiationPath == signup.InitiationPath
|
||||
));
|
||||
|
||||
await sutProvider.GetDependency<ICollectionRepository>()
|
||||
.Received(1)
|
||||
.CreateAsync(
|
||||
Arg.Is<Collection>(c =>
|
||||
c.Name == collectionName &&
|
||||
c.OrganizationId == result.Organization.Id
|
||||
),
|
||||
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
|
||||
Arg.Any<IEnumerable<CollectionAccessSelection>>()
|
||||
);
|
||||
|
||||
await sutProvider.GetDependency<IOrganizationApiKeyRepository>()
|
||||
.Received(1)
|
||||
.CreateAsync(
|
||||
Arg.Is<OrganizationApiKey>(k =>
|
||||
k.OrganizationId == result.Organization.Id &&
|
||||
k.Type == OrganizationApiKeyType.Default
|
||||
)
|
||||
);
|
||||
|
||||
await sutProvider.GetDependency<IApplicationCacheService>()
|
||||
.Received(1)
|
||||
.UpsertOrganizationAbilityAsync(Arg.Is<Organization>(o => o.Id == result.Organization.Id));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SignupClientAsync_NullPlan_ThrowsBadRequestException(
|
||||
OrganizationSignup signup,
|
||||
SutProvider<ProviderClientOrganizationSignUpCommand> sutProvider)
|
||||
{
|
||||
sutProvider.GetDependency<IPricingClient>()
|
||||
.GetPlanOrThrow(signup.Plan)
|
||||
.Returns((Plan)null);
|
||||
|
||||
var exception = await Assert.ThrowsAsync<BadRequestException>(
|
||||
() => sutProvider.Sut.SignUpClientOrganizationAsync(signup));
|
||||
|
||||
Assert.Contains(ProviderClientOrganizationSignUpCommand.PlanNullErrorMessage, exception.Message);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SignupClientAsync_NegativeAdditionalSeats_ThrowsBadRequestException(
|
||||
OrganizationSignup signup,
|
||||
SutProvider<ProviderClientOrganizationSignUpCommand> sutProvider)
|
||||
{
|
||||
signup.Plan = PlanType.TeamsMonthly;
|
||||
signup.AdditionalSeats = -5;
|
||||
|
||||
var plan = StaticStore.GetPlan(signup.Plan);
|
||||
sutProvider.GetDependency<IPricingClient>()
|
||||
.GetPlanOrThrow(signup.Plan)
|
||||
.Returns(plan);
|
||||
|
||||
var exception = await Assert.ThrowsAsync<BadRequestException>(
|
||||
() => sutProvider.Sut.SignUpClientOrganizationAsync(signup));
|
||||
|
||||
Assert.Contains(ProviderClientOrganizationSignUpCommand.AdditionalSeatsNegativeErrorMessage, exception.Message);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(PlanType.TeamsAnnually)]
|
||||
public async Task SignupClientAsync_WhenExceptionIsThrown_CleanupIsPerformed(
|
||||
PlanType planType,
|
||||
OrganizationSignup signup,
|
||||
SutProvider<ProviderClientOrganizationSignUpCommand> sutProvider)
|
||||
{
|
||||
signup.Plan = planType;
|
||||
|
||||
var plan = StaticStore.GetPlan(signup.Plan);
|
||||
sutProvider.GetDependency<IPricingClient>()
|
||||
.GetPlanOrThrow(signup.Plan)
|
||||
.Returns(plan);
|
||||
|
||||
sutProvider.GetDependency<IOrganizationApiKeyRepository>()
|
||||
.When(x => x.CreateAsync(Arg.Any<OrganizationApiKey>()))
|
||||
.Do(_ => throw new Exception());
|
||||
|
||||
var thrownException = await Assert.ThrowsAsync<Exception>(
|
||||
() => sutProvider.Sut.SignUpClientOrganizationAsync(signup));
|
||||
|
||||
await sutProvider.GetDependency<IOrganizationRepository>()
|
||||
.Received(1)
|
||||
.DeleteAsync(Arg.Is<Organization>(o => o.Name == signup.Name));
|
||||
|
||||
await sutProvider.GetDependency<IApplicationCacheService>()
|
||||
.Received(1)
|
||||
.DeleteOrganizationAbilityAsync(Arg.Any<Guid>());
|
||||
}
|
||||
}
|
@ -9,7 +9,6 @@ using Bit.Core.Auth.Models.Business.Tokenables;
|
||||
using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Billing.Pricing;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Models.Business;
|
||||
@ -177,47 +176,6 @@ public class OrganizationServiceTests
|
||||
referenceEvent.Users == expectedNewUsersCount));
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task SignupClientAsync_Succeeds(
|
||||
OrganizationSignup signup,
|
||||
SutProvider<OrganizationService> sutProvider)
|
||||
{
|
||||
signup.Plan = PlanType.TeamsMonthly;
|
||||
|
||||
var plan = StaticStore.GetPlan(signup.Plan);
|
||||
|
||||
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(signup.Plan).Returns(plan);
|
||||
|
||||
var (organization, _, _) = await sutProvider.Sut.SignupClientAsync(signup);
|
||||
|
||||
await sutProvider.GetDependency<IOrganizationRepository>().Received(1).CreateAsync(Arg.Is<Organization>(org =>
|
||||
org.Id == organization.Id &&
|
||||
org.Name == signup.Name &&
|
||||
org.Plan == plan.Name &&
|
||||
org.PlanType == plan.Type &&
|
||||
org.UsePolicies == plan.HasPolicies &&
|
||||
org.PublicKey == signup.PublicKey &&
|
||||
org.PrivateKey == signup.PrivateKey &&
|
||||
org.UseSecretsManager == false));
|
||||
|
||||
await sutProvider.GetDependency<IOrganizationApiKeyRepository>().Received(1)
|
||||
.CreateAsync(Arg.Is<OrganizationApiKey>(orgApiKey =>
|
||||
orgApiKey.OrganizationId == organization.Id));
|
||||
|
||||
await sutProvider.GetDependency<IApplicationCacheService>().Received(1)
|
||||
.UpsertOrganizationAbilityAsync(organization);
|
||||
|
||||
await sutProvider.GetDependency<IOrganizationUserRepository>().DidNotReceiveWithAnyArgs().CreateAsync(default);
|
||||
|
||||
await sutProvider.GetDependency<ICollectionRepository>().Received(1)
|
||||
.CreateAsync(Arg.Is<Collection>(c => c.Name == signup.CollectionName && c.OrganizationId == organization.Id), null, null);
|
||||
|
||||
await sutProvider.GetDependency<IReferenceEventService>().Received(1).RaiseEventAsync(Arg.Is<ReferenceEvent>(
|
||||
re =>
|
||||
re.Type == ReferenceEventType.Signup &&
|
||||
re.PlanType == plan.Type));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[OrganizationInviteCustomize(InviteeUserType = OrganizationUserType.User,
|
||||
InvitorUserType = OrganizationUserType.Owner), OrganizationCustomize, BitAutoData]
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -3,14 +3,11 @@ using Bit.Core.AdminConsole.Entities.Provider;
|
||||
using Bit.Core.Billing.Caches;
|
||||
using Bit.Core.Billing.Constants;
|
||||
using Bit.Core.Billing.Models;
|
||||
using Bit.Core.Billing.Services.Contracts;
|
||||
using Bit.Core.Billing.Services.Implementations;
|
||||
using Bit.Core.Billing.Tax.Models;
|
||||
using Bit.Core.Billing.Tax.Services;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Settings;
|
||||
using Bit.Core.Test.Billing.Tax.Services;
|
||||
using Bit.Test.Common.AutoFixture;
|
||||
using Bit.Test.Common.AutoFixture.Attributes;
|
||||
using Braintree;
|
||||
@ -195,7 +192,7 @@ public class SubscriberServiceTests
|
||||
|
||||
await stripeAdapter
|
||||
.DidNotReceiveWithAnyArgs()
|
||||
.SubscriptionCancelAsync(Arg.Any<string>(), Arg.Any<SubscriptionCancelOptions>()); ;
|
||||
.SubscriptionCancelAsync(Arg.Any<string>(), Arg.Any<SubscriptionCancelOptions>());
|
||||
}
|
||||
|
||||
#endregion
|
||||
@ -1029,7 +1026,7 @@ public class SubscriberServiceTests
|
||||
|
||||
stripeAdapter
|
||||
.PaymentMethodListAutoPagingAsync(Arg.Any<PaymentMethodListOptions>())
|
||||
.Returns(GetPaymentMethodsAsync(new List<Stripe.PaymentMethod>()));
|
||||
.Returns(GetPaymentMethodsAsync(new List<PaymentMethod>()));
|
||||
|
||||
await sutProvider.Sut.RemovePaymentSource(organization);
|
||||
|
||||
@ -1061,7 +1058,7 @@ public class SubscriberServiceTests
|
||||
|
||||
stripeAdapter
|
||||
.PaymentMethodListAutoPagingAsync(Arg.Any<PaymentMethodListOptions>())
|
||||
.Returns(GetPaymentMethodsAsync(new List<Stripe.PaymentMethod>
|
||||
.Returns(GetPaymentMethodsAsync(new List<PaymentMethod>
|
||||
{
|
||||
new ()
|
||||
{
|
||||
@ -1086,8 +1083,8 @@ public class SubscriberServiceTests
|
||||
.PaymentMethodDetachAsync(cardId);
|
||||
}
|
||||
|
||||
private static async IAsyncEnumerable<Stripe.PaymentMethod> GetPaymentMethodsAsync(
|
||||
IEnumerable<Stripe.PaymentMethod> paymentMethods)
|
||||
private static async IAsyncEnumerable<PaymentMethod> GetPaymentMethodsAsync(
|
||||
IEnumerable<PaymentMethod> paymentMethods)
|
||||
{
|
||||
foreach (var paymentMethod in paymentMethods)
|
||||
{
|
||||
@ -1598,14 +1595,22 @@ public class SubscriberServiceTests
|
||||
City = "Example Town",
|
||||
State = "NY"
|
||||
},
|
||||
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] }
|
||||
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] },
|
||||
Subscriptions = new StripeList<Subscription>
|
||||
{
|
||||
Data = [
|
||||
new Subscription
|
||||
{
|
||||
Id = provider.GatewaySubscriptionId,
|
||||
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
|
||||
}
|
||||
]
|
||||
}
|
||||
});
|
||||
|
||||
var subscription = new Subscription { Items = new StripeList<SubscriptionItem>() };
|
||||
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(Arg.Any<string>())
|
||||
.Returns(subscription);
|
||||
sutProvider.GetDependency<IAutomaticTaxFactory>().CreateAsync(Arg.Any<AutomaticTaxFactoryParameters>())
|
||||
.Returns(new FakeAutomaticTaxStrategy(true));
|
||||
|
||||
await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation);
|
||||
|
||||
@ -1623,6 +1628,98 @@ public class SubscriberServiceTests
|
||||
await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is<TaxIdCreateOptions>(
|
||||
options => options.Type == "us_ein" &&
|
||||
options.Value == taxInformation.TaxId));
|
||||
|
||||
await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
|
||||
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task UpdateTaxInformation_NonUser_ReverseCharge_MakesCorrectInvocations(
|
||||
Provider provider,
|
||||
SutProvider<SubscriberService> sutProvider)
|
||||
{
|
||||
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
|
||||
|
||||
var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } };
|
||||
|
||||
stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is<CustomerGetOptions>(
|
||||
options => options.Expand.Contains("tax_ids"))).Returns(customer);
|
||||
|
||||
var taxInformation = new TaxInformation(
|
||||
"CA",
|
||||
"12345",
|
||||
"123456789",
|
||||
"us_ein",
|
||||
"123 Example St.",
|
||||
null,
|
||||
"Example Town",
|
||||
"NY");
|
||||
|
||||
sutProvider.GetDependency<IStripeAdapter>()
|
||||
.CustomerUpdateAsync(
|
||||
Arg.Is<string>(p => p == provider.GatewayCustomerId),
|
||||
Arg.Is<CustomerUpdateOptions>(options =>
|
||||
options.Address.Country == "CA" &&
|
||||
options.Address.PostalCode == "12345" &&
|
||||
options.Address.Line1 == "123 Example St." &&
|
||||
options.Address.Line2 == null &&
|
||||
options.Address.City == "Example Town" &&
|
||||
options.Address.State == "NY"))
|
||||
.Returns(new Customer
|
||||
{
|
||||
Id = provider.GatewayCustomerId,
|
||||
Address = new Address
|
||||
{
|
||||
Country = "CA",
|
||||
PostalCode = "12345",
|
||||
Line1 = "123 Example St.",
|
||||
Line2 = null,
|
||||
City = "Example Town",
|
||||
State = "NY"
|
||||
},
|
||||
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] },
|
||||
Subscriptions = new StripeList<Subscription>
|
||||
{
|
||||
Data = [
|
||||
new Subscription
|
||||
{
|
||||
Id = provider.GatewaySubscriptionId,
|
||||
CustomerId = provider.GatewayCustomerId,
|
||||
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
|
||||
}
|
||||
]
|
||||
}
|
||||
});
|
||||
|
||||
var subscription = new Subscription { Items = new StripeList<SubscriptionItem>() };
|
||||
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(Arg.Any<string>())
|
||||
.Returns(subscription);
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>()
|
||||
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
|
||||
|
||||
await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation);
|
||||
|
||||
await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(
|
||||
options =>
|
||||
options.Address.Country == taxInformation.Country &&
|
||||
options.Address.PostalCode == taxInformation.PostalCode &&
|
||||
options.Address.Line1 == taxInformation.Line1 &&
|
||||
options.Address.Line2 == taxInformation.Line2 &&
|
||||
options.Address.City == taxInformation.City &&
|
||||
options.Address.State == taxInformation.State));
|
||||
|
||||
await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1");
|
||||
|
||||
await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is<TaxIdCreateOptions>(
|
||||
options => options.Type == "us_ein" &&
|
||||
options.Value == taxInformation.TaxId));
|
||||
|
||||
await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId,
|
||||
Arg.Is<CustomerUpdateOptions>(options => options.TaxExempt == StripeConstants.TaxExempt.Reverse));
|
||||
|
||||
await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
|
||||
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
@ -28,7 +28,10 @@ public static class OrganizationLicenseFileFixtures
|
||||
private const string Version15 =
|
||||
"{\n 'LicenseKey': 'myLicenseKey',\n 'InstallationId': '78900000-0000-0000-0000-000000000123',\n 'Id': '12300000-0000-0000-0000-000000000456',\n 'Name': 'myOrg',\n 'BillingEmail': 'myBillingEmail',\n 'BusinessName': 'myBusinessName',\n 'Enabled': true,\n 'Plan': 'myPlan',\n 'PlanType': 11,\n 'Seats': 10,\n 'MaxCollections': 2,\n 'UsePolicies': true,\n 'UseSso': true,\n 'UseKeyConnector': true,\n 'UseScim': true,\n 'UseGroups': true,\n 'UseEvents': true,\n 'UseDirectory': true,\n 'UseTotp': true,\n 'Use2fa': true,\n 'UseApi': true,\n 'UseResetPassword': true,\n 'MaxStorageGb': 100,\n 'SelfHost': true,\n 'UsersGetPremium': true,\n 'UseCustomPermissions': true,\n 'Version': 14,\n 'Issued': '2023-12-14T02:03:33.374297Z',\n 'Refresh': '2023-12-07T22:42:33.970597Z',\n 'Expires': '2023-12-21T02:03:33.374297Z',\n 'ExpirationWithoutGracePeriod': null,\n 'UsePasswordManager': true,\n 'UseSecretsManager': true,\n 'SmSeats': 5,\n 'SmServiceAccounts': 8,\n 'LimitCollectionCreationDeletion': true,\n 'AllowAdminAccessToAllCollectionItems': true,\n 'Trial': true,\n 'LicenseType': 1,\n 'Hash': 'EZl4IvJaa1E5mPmlfp4p5twAtlmaxlF1yoZzVYP4vog=',\n 'Signature': ''\n}";
|
||||
|
||||
private static readonly Dictionary<int, string> LicenseVersions = new() { { 12, Version12 }, { 13, Version13 }, { 14, Version14 }, { 15, Version15 } };
|
||||
private const string Version16 =
|
||||
"{\n'LicenseKey': 'myLicenseKey',\n'InstallationId': '78900000-0000-0000-0000-000000000123',\n'Id': '12300000-0000-0000-0000-000000000456',\n'Name': 'myOrg',\n'BillingEmail': 'myBillingEmail',\n'BusinessName': 'myBusinessName',\n'Enabled': true,\n'Plan': 'myPlan',\n'PlanType': 11,\n'Seats': 10,\n'MaxCollections': 2,\n'UsePolicies': true,\n'UseSso': true,\n'UseKeyConnector': true,\n'UseScim': true,\n'UseGroups': true,\n'UseEvents': true,\n'UseDirectory': true,\n'UseTotp': true,\n'Use2fa': true,\n'UseApi': true,\n'UseResetPassword': true,\n'MaxStorageGb': 100,\n'SelfHost': true,\n'UsersGetPremium': true,\n'UseCustomPermissions': true,\n'Version': 15,\n'Issued': '2025-05-16T20:50:09.036931Z',\n'Refresh': '2025-05-23T20:50:09.036931Z',\n'Expires': '2025-05-23T20:50:09.036931Z',\n'ExpirationWithoutGracePeriod': null,\n'UsePasswordManager': true,\n'UseSecretsManager': true,\n'SmSeats': 5,\n'SmServiceAccounts': 8,\n'UseRiskInsights': false,\n'LimitCollectionCreationDeletion': true,\n'AllowAdminAccessToAllCollectionItems': true,\n'Trial': true,\n'LicenseType': 1,\n'UseOrganizationDomains': true,\n'UseAdminSponsoredFamilies': false,\n'Hash': 'k3M9SpHKUo0TmuSnNipeZleCHxgcEycKRXYl9BAg30Q=',\n'Signature': '',\n'Token': null\n}";
|
||||
|
||||
private static readonly Dictionary<int, string> LicenseVersions = new() { { 12, Version12 }, { 13, Version13 }, { 14, Version14 }, { 15, Version15 }, { 16, Version16 } };
|
||||
|
||||
public static OrganizationLicense GetVersion(int licenseVersion)
|
||||
{
|
||||
|
@ -347,7 +347,7 @@ public class UserServiceTests
|
||||
SutProvider<UserService> sutProvider, Guid userId, Organization organization)
|
||||
{
|
||||
organization.Enabled = true;
|
||||
organization.UseSso = true;
|
||||
organization.UseOrganizationDomains = true;
|
||||
|
||||
sutProvider.GetDependency<IOrganizationRepository>()
|
||||
.GetByVerifiedUserEmailDomainAsync(userId)
|
||||
@ -362,7 +362,7 @@ public class UserServiceTests
|
||||
SutProvider<UserService> sutProvider, Guid userId, Organization organization)
|
||||
{
|
||||
organization.Enabled = false;
|
||||
organization.UseSso = true;
|
||||
organization.UseOrganizationDomains = true;
|
||||
|
||||
sutProvider.GetDependency<IOrganizationRepository>()
|
||||
.GetByVerifiedUserEmailDomainAsync(userId)
|
||||
@ -373,11 +373,11 @@ public class UserServiceTests
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task IsClaimedByAnyOrganizationAsync_WithOrganizationUseSsoFalse_ReturnsFalse(
|
||||
public async Task IsClaimedByAnyOrganizationAsync_WithOrganizationUseOrganizationDomaisFalse_ReturnsFalse(
|
||||
SutProvider<UserService> sutProvider, Guid userId, Organization organization)
|
||||
{
|
||||
organization.Enabled = true;
|
||||
organization.UseSso = false;
|
||||
organization.UseOrganizationDomains = false;
|
||||
|
||||
sutProvider.GetDependency<IOrganizationRepository>()
|
||||
.GetByVerifiedUserEmailDomainAsync(userId)
|
||||
|
118
test/Core.Test/Tools/Services/AnonymousSendCommandTests.cs
Normal file
118
test/Core.Test/Tools/Services/AnonymousSendCommandTests.cs
Normal file
@ -0,0 +1,118 @@
|
||||
using System.Text.Json;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Tools.SendFeatures.Commands;
|
||||
using Bit.Core.Tools.Services;
|
||||
using NSubstitute;
|
||||
using Xunit;
|
||||
|
||||
namespace Bit.Core.Test.Tools.Services;
|
||||
|
||||
public class AnonymousSendCommandTests
|
||||
{
|
||||
private readonly ISendRepository _sendRepository;
|
||||
private readonly ISendFileStorageService _sendFileStorageService;
|
||||
private readonly IPushNotificationService _pushNotificationService;
|
||||
private readonly ISendAuthorizationService _sendAuthorizationService;
|
||||
private readonly AnonymousSendCommand _anonymousSendCommand;
|
||||
|
||||
public AnonymousSendCommandTests()
|
||||
{
|
||||
_sendRepository = Substitute.For<ISendRepository>();
|
||||
_sendFileStorageService = Substitute.For<ISendFileStorageService>();
|
||||
_pushNotificationService = Substitute.For<IPushNotificationService>();
|
||||
_sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
|
||||
|
||||
_anonymousSendCommand = new AnonymousSendCommand(
|
||||
_sendRepository,
|
||||
_sendFileStorageService,
|
||||
_pushNotificationService,
|
||||
_sendAuthorizationService);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetSendFileDownloadUrlAsync_Success_ReturnsDownloadUrl()
|
||||
{
|
||||
// Arrange
|
||||
var send = new Send
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
Type = SendType.File,
|
||||
AccessCount = 0,
|
||||
Data = JsonSerializer.Serialize(new { Id = "fileId123" })
|
||||
};
|
||||
var fileId = "fileId123";
|
||||
var password = "testPassword";
|
||||
var expectedUrl = "https://example.com/download";
|
||||
|
||||
_sendAuthorizationService
|
||||
.SendCanBeAccessed(send, password)
|
||||
.Returns(SendAccessResult.Granted);
|
||||
|
||||
_sendFileStorageService
|
||||
.GetSendFileDownloadUrlAsync(send, fileId)
|
||||
.Returns(expectedUrl);
|
||||
|
||||
// Act
|
||||
var result =
|
||||
await _anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId, password);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(expectedUrl, result.Item1);
|
||||
Assert.Equal(1, send.AccessCount);
|
||||
|
||||
await _sendRepository.Received(1).ReplaceAsync(send);
|
||||
await _pushNotificationService.Received(1).PushSyncSendUpdateAsync(send);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetSendFileDownloadUrlAsync_AccessDenied_ReturnsNullWithReasons()
|
||||
{
|
||||
// Arrange
|
||||
var send = new Send
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
Type = SendType.File,
|
||||
AccessCount = 0
|
||||
};
|
||||
var fileId = "fileId123";
|
||||
var password = "wrongPassword";
|
||||
|
||||
_sendAuthorizationService
|
||||
.SendCanBeAccessed(send, password)
|
||||
.Returns(SendAccessResult.Denied);
|
||||
|
||||
// Act
|
||||
var result =
|
||||
await _anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId, password);
|
||||
|
||||
// Assert
|
||||
Assert.Null(result.Item1);
|
||||
Assert.Equal(SendAccessResult.Denied, result.Item2);
|
||||
Assert.Equal(0, send.AccessCount);
|
||||
|
||||
await _sendRepository.DidNotReceiveWithAnyArgs().ReplaceAsync(default);
|
||||
await _pushNotificationService.DidNotReceiveWithAnyArgs().PushSyncSendUpdateAsync(default);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetSendFileDownloadUrlAsync_NotFileSend_ThrowsBadRequestException()
|
||||
{
|
||||
// Arrange
|
||||
var send = new Send
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
Type = SendType.Text
|
||||
};
|
||||
var fileId = "fileId123";
|
||||
var password = "testPassword";
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
_anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId, password));
|
||||
}
|
||||
}
|
1111
test/Core.Test/Tools/Services/NonAnonymousSendCommandTests.cs
Normal file
1111
test/Core.Test/Tools/Services/NonAnonymousSendCommandTests.cs
Normal file
File diff suppressed because it is too large
Load Diff
175
test/Core.Test/Tools/Services/SendAuthorizationServiceTests.cs
Normal file
175
test/Core.Test/Tools/Services/SendAuthorizationServiceTests.cs
Normal file
@ -0,0 +1,175 @@
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Microsoft.AspNetCore.Identity;
|
||||
using NSubstitute;
|
||||
using Xunit;
|
||||
|
||||
namespace Bit.Core.Test.Tools.Services;
|
||||
|
||||
public class SendAuthorizationServiceTests
|
||||
{
|
||||
private readonly ISendRepository _sendRepository;
|
||||
private readonly IPasswordHasher<Bit.Core.Entities.User> _passwordHasher;
|
||||
private readonly IPushNotificationService _pushNotificationService;
|
||||
private readonly IReferenceEventService _referenceEventService;
|
||||
private readonly ICurrentContext _currentContext;
|
||||
private readonly SendAuthorizationService _sendAuthorizationService;
|
||||
|
||||
public SendAuthorizationServiceTests()
|
||||
{
|
||||
_sendRepository = Substitute.For<ISendRepository>();
|
||||
_passwordHasher = Substitute.For<IPasswordHasher<Bit.Core.Entities.User>>();
|
||||
_pushNotificationService = Substitute.For<IPushNotificationService>();
|
||||
_referenceEventService = Substitute.For<IReferenceEventService>();
|
||||
_currentContext = Substitute.For<ICurrentContext>();
|
||||
|
||||
_sendAuthorizationService = new SendAuthorizationService(
|
||||
_sendRepository,
|
||||
_passwordHasher,
|
||||
_pushNotificationService,
|
||||
_referenceEventService,
|
||||
_currentContext);
|
||||
}
|
||||
|
||||
|
||||
[Fact]
|
||||
public void SendCanBeAccessed_Success_ReturnsTrue()
|
||||
{
|
||||
// Arrange
|
||||
var send = new Send
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
UserId = Guid.NewGuid(),
|
||||
MaxAccessCount = 10,
|
||||
AccessCount = 5,
|
||||
ExpirationDate = DateTime.UtcNow.AddYears(1),
|
||||
DeletionDate = DateTime.UtcNow.AddYears(1),
|
||||
Disabled = false,
|
||||
Password = "hashedPassword123"
|
||||
};
|
||||
|
||||
const string password = "TEST";
|
||||
|
||||
_passwordHasher
|
||||
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), send.Password, password)
|
||||
.Returns(PasswordVerificationResult.Success);
|
||||
|
||||
// Act
|
||||
var result =
|
||||
_sendAuthorizationService.SendCanBeAccessed(send, password);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(SendAccessResult.Granted, result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SendCanBeAccessed_NullMaxAccess_Success()
|
||||
{
|
||||
// Arrange
|
||||
var send = new Send
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
UserId = Guid.NewGuid(),
|
||||
MaxAccessCount = null,
|
||||
AccessCount = 5,
|
||||
ExpirationDate = DateTime.UtcNow.AddYears(1),
|
||||
DeletionDate = DateTime.UtcNow.AddYears(1),
|
||||
Disabled = false,
|
||||
Password = "hashedPassword123"
|
||||
};
|
||||
|
||||
const string password = "TEST";
|
||||
|
||||
_passwordHasher
|
||||
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), send.Password, password)
|
||||
.Returns(PasswordVerificationResult.Success);
|
||||
|
||||
// Act
|
||||
var result = _sendAuthorizationService.SendCanBeAccessed(send, password);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(SendAccessResult.Granted, result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SendCanBeAccessed_NullSend_DoesNotGrantAccess()
|
||||
{
|
||||
// Arrange
|
||||
_passwordHasher
|
||||
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST", "TEST")
|
||||
.Returns(PasswordVerificationResult.Success);
|
||||
|
||||
// Act
|
||||
var result =
|
||||
_sendAuthorizationService.SendCanBeAccessed(null, "TEST");
|
||||
|
||||
// Assert
|
||||
Assert.Equal(SendAccessResult.Denied, result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SendCanBeAccessed_RehashNeeded_RehashesPassword()
|
||||
{
|
||||
// Arrange
|
||||
var now = DateTime.UtcNow;
|
||||
var send = new Send
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
UserId = Guid.NewGuid(),
|
||||
MaxAccessCount = null,
|
||||
AccessCount = 5,
|
||||
ExpirationDate = now.AddYears(1),
|
||||
DeletionDate = now.AddYears(1),
|
||||
Disabled = false,
|
||||
Password = "TEST"
|
||||
};
|
||||
|
||||
_passwordHasher
|
||||
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST", "TEST")
|
||||
.Returns(PasswordVerificationResult.SuccessRehashNeeded);
|
||||
|
||||
// Act
|
||||
var result =
|
||||
_sendAuthorizationService.SendCanBeAccessed(send, "TEST");
|
||||
|
||||
// Assert
|
||||
_passwordHasher
|
||||
.Received(1)
|
||||
.HashPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST");
|
||||
|
||||
Assert.Equal(SendAccessResult.Granted, result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue()
|
||||
{
|
||||
// Arrange
|
||||
var now = DateTime.UtcNow;
|
||||
var send = new Send
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
UserId = Guid.NewGuid(),
|
||||
MaxAccessCount = null,
|
||||
AccessCount = 5,
|
||||
ExpirationDate = now.AddYears(1),
|
||||
DeletionDate = now.AddYears(1),
|
||||
Disabled = false,
|
||||
Password = "TEST"
|
||||
};
|
||||
|
||||
_passwordHasher
|
||||
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST", "TEST")
|
||||
.Returns(PasswordVerificationResult.Failed);
|
||||
|
||||
// Act
|
||||
var result =
|
||||
_sendAuthorizationService.SendCanBeAccessed(send, "TEST");
|
||||
|
||||
// Assert
|
||||
Assert.Equal(SendAccessResult.PasswordInvalid, result);
|
||||
}
|
||||
}
|
@ -1,867 +0,0 @@
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using Bit.Core.AdminConsole.Entities;
|
||||
using Bit.Core.AdminConsole.Enums;
|
||||
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
|
||||
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
|
||||
using Bit.Core.AdminConsole.Services;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Test.AutoFixture.CurrentContextFixtures;
|
||||
using Bit.Core.Test.Entities;
|
||||
using Bit.Core.Test.Tools.AutoFixture.SendFixtures;
|
||||
using Bit.Core.Tools.Entities;
|
||||
using Bit.Core.Tools.Enums;
|
||||
using Bit.Core.Tools.Models.Data;
|
||||
using Bit.Core.Tools.Repositories;
|
||||
using Bit.Core.Tools.Services;
|
||||
using Bit.Test.Common.AutoFixture;
|
||||
using Bit.Test.Common.AutoFixture.Attributes;
|
||||
using Microsoft.AspNetCore.Identity;
|
||||
using NSubstitute;
|
||||
using NSubstitute.ExceptionExtensions;
|
||||
using Xunit;
|
||||
|
||||
using GlobalSettings = Bit.Core.Settings.GlobalSettings;
|
||||
|
||||
namespace Bit.Core.Test.Tools.Services;
|
||||
|
||||
[SutProviderCustomize]
|
||||
[CurrentContextCustomize]
|
||||
[UserSendCustomize]
|
||||
public class SendServiceTests
|
||||
{
|
||||
private void SaveSendAsync_Setup(SendType sendType, bool disableSendPolicyAppliesToUser,
|
||||
SutProvider<SendService> sutProvider, Send send)
|
||||
{
|
||||
send.Id = default;
|
||||
send.Type = sendType;
|
||||
|
||||
sutProvider.GetDependency<IPolicyService>().AnyPoliciesApplicableToUserAsync(
|
||||
Arg.Any<Guid>(), PolicyType.DisableSend).Returns(disableSendPolicyAppliesToUser);
|
||||
}
|
||||
|
||||
// Disable Send policy check
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(SendType.File)]
|
||||
[BitAutoData(SendType.Text)]
|
||||
public async Task SaveSendAsync_DisableSend_Applies_throws(SendType sendType,
|
||||
SutProvider<SendService> sutProvider, Send send)
|
||||
{
|
||||
SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: true, sutProvider, send);
|
||||
|
||||
await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(SendType.File)]
|
||||
[BitAutoData(SendType.Text)]
|
||||
public async Task SaveSendAsync_DisableSend_DoesntApply_success(SendType sendType,
|
||||
SutProvider<SendService> sutProvider, Send send)
|
||||
{
|
||||
SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: false, sutProvider, send);
|
||||
|
||||
await sutProvider.Sut.SaveSendAsync(send);
|
||||
|
||||
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
|
||||
}
|
||||
|
||||
// Send Options Policy - Disable Hide Email check
|
||||
|
||||
private void SaveSendAsync_HideEmail_Setup(bool disableHideEmailAppliesToUser,
|
||||
SutProvider<SendService> sutProvider, Send send, Policy policy)
|
||||
{
|
||||
send.HideEmail = true;
|
||||
|
||||
var sendOptions = new SendOptionsPolicyData
|
||||
{
|
||||
DisableHideEmail = disableHideEmailAppliesToUser
|
||||
};
|
||||
policy.Data = JsonSerializer.Serialize(sendOptions, new JsonSerializerOptions
|
||||
{
|
||||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
|
||||
});
|
||||
|
||||
sutProvider.GetDependency<IPolicyService>().GetPoliciesApplicableToUserAsync(
|
||||
Arg.Any<Guid>(), PolicyType.SendOptions).Returns(new List<OrganizationUserPolicyDetails>()
|
||||
{
|
||||
new() { PolicyType = policy.Type, PolicyData = policy.Data, OrganizationId = policy.OrganizationId, PolicyEnabled = policy.Enabled }
|
||||
});
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(SendType.File)]
|
||||
[BitAutoData(SendType.Text)]
|
||||
public async Task SaveSendAsync_DisableHideEmail_Applies_throws(SendType sendType,
|
||||
SutProvider<SendService> sutProvider, Send send, Policy policy)
|
||||
{
|
||||
SaveSendAsync_Setup(sendType, false, sutProvider, send);
|
||||
SaveSendAsync_HideEmail_Setup(true, sutProvider, send, policy);
|
||||
|
||||
await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(SendType.File)]
|
||||
[BitAutoData(SendType.Text)]
|
||||
public async Task SaveSendAsync_DisableHideEmail_DoesntApply_success(SendType sendType,
|
||||
SutProvider<SendService> sutProvider, Send send, Policy policy)
|
||||
{
|
||||
SaveSendAsync_Setup(sendType, false, sutProvider, send);
|
||||
SaveSendAsync_HideEmail_Setup(false, sutProvider, send, policy);
|
||||
|
||||
await sutProvider.Sut.SaveSendAsync(send);
|
||||
|
||||
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
|
||||
}
|
||||
|
||||
// Disable Send policy check - vNext
|
||||
private void SaveSendAsync_Setup_vNext(SutProvider<SendService> sutProvider, Send send,
|
||||
DisableSendPolicyRequirement disableSendPolicyRequirement, SendOptionsPolicyRequirement sendOptionsPolicyRequirement)
|
||||
{
|
||||
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<DisableSendPolicyRequirement>(send.UserId!.Value)
|
||||
.Returns(disableSendPolicyRequirement);
|
||||
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<SendOptionsPolicyRequirement>(send.UserId!.Value)
|
||||
.Returns(sendOptionsPolicyRequirement);
|
||||
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
|
||||
|
||||
// Should not be called in these tests
|
||||
sutProvider.GetDependency<IPolicyService>().AnyPoliciesApplicableToUserAsync(
|
||||
Arg.Any<Guid>(), Arg.Any<PolicyType>()).ThrowsAsync<Exception>();
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(SendType.File)]
|
||||
[BitAutoData(SendType.Text)]
|
||||
public async Task SaveSendAsync_DisableSend_Applies_Throws_vNext(SendType sendType,
|
||||
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
|
||||
{
|
||||
send.Type = sendType;
|
||||
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement { DisableSend = true }, new SendOptionsPolicyRequirement());
|
||||
|
||||
var exception = await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
|
||||
Assert.Contains("Due to an Enterprise Policy, you are only able to delete an existing Send.",
|
||||
exception.Message);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(SendType.File)]
|
||||
[BitAutoData(SendType.Text)]
|
||||
public async Task SaveSendAsync_DisableSend_DoesntApply_Success_vNext(SendType sendType,
|
||||
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
|
||||
{
|
||||
send.Type = sendType;
|
||||
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement());
|
||||
|
||||
await sutProvider.Sut.SaveSendAsync(send);
|
||||
|
||||
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
|
||||
}
|
||||
|
||||
// Send Options Policy - Disable Hide Email check
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(SendType.File)]
|
||||
[BitAutoData(SendType.Text)]
|
||||
public async Task SaveSendAsync_DisableHideEmail_Applies_Throws_vNext(SendType sendType,
|
||||
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
|
||||
{
|
||||
send.Type = sendType;
|
||||
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement { DisableHideEmail = true });
|
||||
send.HideEmail = true;
|
||||
|
||||
var exception = await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
|
||||
Assert.Contains("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.", exception.Message);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(SendType.File)]
|
||||
[BitAutoData(SendType.Text)]
|
||||
public async Task SaveSendAsync_DisableHideEmail_Applies_ButEmailNotHidden_Success_vNext(SendType sendType,
|
||||
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
|
||||
{
|
||||
send.Type = sendType;
|
||||
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement { DisableHideEmail = true });
|
||||
send.HideEmail = false;
|
||||
|
||||
await sutProvider.Sut.SaveSendAsync(send);
|
||||
|
||||
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData(SendType.File)]
|
||||
[BitAutoData(SendType.Text)]
|
||||
public async Task SaveSendAsync_DisableHideEmail_DoesntApply_Success_vNext(SendType sendType,
|
||||
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
|
||||
{
|
||||
send.Type = sendType;
|
||||
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement());
|
||||
send.HideEmail = true;
|
||||
|
||||
await sutProvider.Sut.SaveSendAsync(send);
|
||||
|
||||
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveSendAsync_ExistingSend_Updates(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
send.Id = Guid.NewGuid();
|
||||
|
||||
var now = DateTime.UtcNow;
|
||||
await sutProvider.Sut.SaveSendAsync(send);
|
||||
|
||||
Assert.True(send.RevisionDate - now < TimeSpan.FromSeconds(1));
|
||||
|
||||
await sutProvider.GetDependency<ISendRepository>()
|
||||
.Received(1)
|
||||
.UpsertAsync(send);
|
||||
|
||||
await sutProvider.GetDependency<IPushNotificationService>()
|
||||
.Received(1)
|
||||
.PushSyncSendUpdateAsync(send);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_TextType_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
send.Type = SendType.Text;
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 0)
|
||||
);
|
||||
|
||||
Assert.Contains("not of type \"file\"", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_EmptyFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
send.Type = SendType.File;
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 0)
|
||||
);
|
||||
|
||||
Assert.Contains("no file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_UserCannotAccessPremium_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var user = new User
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
};
|
||||
|
||||
send.UserId = user.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetByIdAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.CanAccessPremium(user)
|
||||
.Returns(false);
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
|
||||
);
|
||||
|
||||
Assert.Contains("must have premium", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_UserHasUnconfirmedEmail_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var user = new User
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
EmailVerified = false,
|
||||
};
|
||||
|
||||
send.UserId = user.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetByIdAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.CanAccessPremium(user)
|
||||
.Returns(true);
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
|
||||
);
|
||||
|
||||
Assert.Contains("must confirm your email", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_UserCanAccessPremium_HasNoStorage_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var user = new User
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
EmailVerified = true,
|
||||
Premium = true,
|
||||
MaxStorageGb = null,
|
||||
Storage = 0,
|
||||
};
|
||||
|
||||
send.UserId = user.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetByIdAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.CanAccessPremium(user)
|
||||
.Returns(true);
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
|
||||
);
|
||||
|
||||
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_UserCanAccessPremium_StorageFull_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var user = new User
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
EmailVerified = true,
|
||||
Premium = true,
|
||||
MaxStorageGb = 2,
|
||||
Storage = 2 * UserTests.Multiplier,
|
||||
};
|
||||
|
||||
send.UserId = user.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetByIdAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.CanAccessPremium(user)
|
||||
.Returns(true);
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
|
||||
);
|
||||
|
||||
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsSelfHosted_GiantFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var user = new User
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
EmailVerified = true,
|
||||
Premium = false,
|
||||
};
|
||||
|
||||
send.UserId = user.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetByIdAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.CanAccessPremium(user)
|
||||
.Returns(true);
|
||||
|
||||
sutProvider.GetDependency<GlobalSettings>()
|
||||
.SelfHosted = true;
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 11000 * UserTests.Multiplier)
|
||||
);
|
||||
|
||||
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsNotSelfHosted_TwoGigabyteFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var user = new User
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
EmailVerified = true,
|
||||
Premium = false,
|
||||
};
|
||||
|
||||
send.UserId = user.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetByIdAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.CanAccessPremium(user)
|
||||
.Returns(true);
|
||||
|
||||
sutProvider.GetDependency<GlobalSettings>()
|
||||
.SelfHosted = false;
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier)
|
||||
);
|
||||
|
||||
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var org = new Organization
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
MaxStorageGb = null,
|
||||
};
|
||||
|
||||
send.UserId = null;
|
||||
send.OrganizationId = org.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IOrganizationRepository>()
|
||||
.GetByIdAsync(org.Id)
|
||||
.Returns(org);
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
|
||||
);
|
||||
|
||||
Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_TwoGBFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var org = new Organization
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
MaxStorageGb = null,
|
||||
};
|
||||
|
||||
send.UserId = null;
|
||||
send.OrganizationId = org.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IOrganizationRepository>()
|
||||
.GetByIdAsync(org.Id)
|
||||
.Returns(org);
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
|
||||
);
|
||||
|
||||
Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_ThroughOrg_MaxStorageIsOneGB_TwoGBFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var org = new Organization
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
MaxStorageGb = 1,
|
||||
};
|
||||
|
||||
send.UserId = null;
|
||||
send.OrganizationId = org.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IOrganizationRepository>()
|
||||
.GetByIdAsync(org.Id)
|
||||
.Returns(org);
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier)
|
||||
);
|
||||
|
||||
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_HasEnoughStorage_Success(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var user = new User
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
EmailVerified = true,
|
||||
MaxStorageGb = 10,
|
||||
};
|
||||
|
||||
var data = new SendFileData
|
||||
{
|
||||
|
||||
};
|
||||
|
||||
send.UserId = user.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
var testUrl = "https://test.com/";
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetByIdAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.CanAccessPremium(user)
|
||||
.Returns(true);
|
||||
|
||||
sutProvider.GetDependency<ISendFileStorageService>()
|
||||
.GetSendFileUploadUrlAsync(send, Arg.Any<string>())
|
||||
.Returns(testUrl);
|
||||
|
||||
var utcNow = DateTime.UtcNow;
|
||||
|
||||
var url = await sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier);
|
||||
|
||||
Assert.Equal(testUrl, url);
|
||||
Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1));
|
||||
|
||||
await sutProvider.GetDependency<ISendFileStorageService>()
|
||||
.Received(1)
|
||||
.GetSendFileUploadUrlAsync(send, Arg.Any<string>());
|
||||
|
||||
await sutProvider.GetDependency<ISendRepository>()
|
||||
.Received(1)
|
||||
.UpsertAsync(send);
|
||||
|
||||
await sutProvider.GetDependency<IPushNotificationService>()
|
||||
.Received(1)
|
||||
.PushSyncSendUpdateAsync(send);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SaveFileSendAsync_HasEnoughStorage_SendFileThrows_CleansUp(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var user = new User
|
||||
{
|
||||
Id = Guid.NewGuid(),
|
||||
EmailVerified = true,
|
||||
MaxStorageGb = 10,
|
||||
};
|
||||
|
||||
var data = new SendFileData
|
||||
{
|
||||
|
||||
};
|
||||
|
||||
send.UserId = user.Id;
|
||||
send.Type = SendType.File;
|
||||
|
||||
sutProvider.GetDependency<IUserRepository>()
|
||||
.GetByIdAsync(user.Id)
|
||||
.Returns(user);
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.CanAccessPremium(user)
|
||||
.Returns(true);
|
||||
|
||||
sutProvider.GetDependency<ISendFileStorageService>()
|
||||
.GetSendFileUploadUrlAsync(send, Arg.Any<string>())
|
||||
.Returns<string>(callInfo => throw new Exception("Problem"));
|
||||
|
||||
var utcNow = DateTime.UtcNow;
|
||||
|
||||
var exception = await Assert.ThrowsAsync<Exception>(() =>
|
||||
sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier)
|
||||
);
|
||||
|
||||
Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1));
|
||||
Assert.Equal("Problem", exception.Message);
|
||||
|
||||
await sutProvider.GetDependency<ISendFileStorageService>()
|
||||
.Received(1)
|
||||
.GetSendFileUploadUrlAsync(send, Arg.Any<string>());
|
||||
|
||||
await sutProvider.GetDependency<ISendRepository>()
|
||||
.Received(1)
|
||||
.UpsertAsync(send);
|
||||
|
||||
await sutProvider.GetDependency<IPushNotificationService>()
|
||||
.Received(1)
|
||||
.PushSyncSendUpdateAsync(send);
|
||||
|
||||
await sutProvider.GetDependency<ISendFileStorageService>()
|
||||
.Received(1)
|
||||
.DeleteFileAsync(send, Arg.Any<string>());
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task UpdateFileToExistingSendAsync_SendNull_ThrowsBadRequest(SutProvider<SendService> sutProvider)
|
||||
{
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), null)
|
||||
);
|
||||
|
||||
Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task UpdateFileToExistingSendAsync_SendDataNull_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
send.Data = null;
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send)
|
||||
);
|
||||
|
||||
Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task UpdateFileToExistingSendAsync_NotFileType_ThrowsBadRequest(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send)
|
||||
);
|
||||
|
||||
Assert.Contains("not a file type send", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task UpdateFileToExistingSendAsync_Success(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var fileContents = "Test file content";
|
||||
|
||||
var sendFileData = new SendFileData
|
||||
{
|
||||
Id = "TEST",
|
||||
Size = fileContents.Length,
|
||||
Validated = false,
|
||||
};
|
||||
|
||||
send.Type = SendType.File;
|
||||
send.Data = JsonSerializer.Serialize(sendFileData);
|
||||
|
||||
sutProvider.GetDependency<ISendFileStorageService>()
|
||||
.ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any<long>())
|
||||
.Returns((true, sendFileData.Size));
|
||||
|
||||
await sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task UpdateFileToExistingSendAsync_InvalidSize(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var fileContents = "Test file content";
|
||||
|
||||
var sendFileData = new SendFileData
|
||||
{
|
||||
Id = "TEST",
|
||||
Size = fileContents.Length,
|
||||
};
|
||||
|
||||
send.Type = SendType.File;
|
||||
send.Data = JsonSerializer.Serialize(sendFileData);
|
||||
|
||||
sutProvider.GetDependency<ISendFileStorageService>()
|
||||
.ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any<long>())
|
||||
.Returns((false, sendFileData.Size));
|
||||
|
||||
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
|
||||
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send)
|
||||
);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public void SendCanBeAccessed_Success(SutProvider<SendService> sutProvider, Send send)
|
||||
{
|
||||
var now = DateTime.UtcNow;
|
||||
send.MaxAccessCount = 10;
|
||||
send.AccessCount = 5;
|
||||
send.ExpirationDate = now.AddYears(1);
|
||||
send.DeletionDate = now.AddYears(1);
|
||||
send.Disabled = false;
|
||||
|
||||
sutProvider.GetDependency<IPasswordHasher<User>>()
|
||||
.VerifyHashedPassword(Arg.Any<User>(), send.Password, "TEST")
|
||||
.Returns(PasswordVerificationResult.Success);
|
||||
|
||||
var (grant, passwordRequiredError, passwordInvalidError)
|
||||
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
|
||||
|
||||
Assert.True(grant);
|
||||
Assert.False(passwordRequiredError);
|
||||
Assert.False(passwordInvalidError);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public void SendCanBeAccessed_NullMaxAccess_Success(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var now = DateTime.UtcNow;
|
||||
send.MaxAccessCount = null;
|
||||
send.AccessCount = 5;
|
||||
send.ExpirationDate = now.AddYears(1);
|
||||
send.DeletionDate = now.AddYears(1);
|
||||
send.Disabled = false;
|
||||
|
||||
sutProvider.GetDependency<IPasswordHasher<User>>()
|
||||
.VerifyHashedPassword(Arg.Any<User>(), send.Password, "TEST")
|
||||
.Returns(PasswordVerificationResult.Success);
|
||||
|
||||
var (grant, passwordRequiredError, passwordInvalidError)
|
||||
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
|
||||
|
||||
Assert.True(grant);
|
||||
Assert.False(passwordRequiredError);
|
||||
Assert.False(passwordInvalidError);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public void SendCanBeAccessed_NullSend_DoesNotGrantAccess(SutProvider<SendService> sutProvider)
|
||||
{
|
||||
sutProvider.GetDependency<IPasswordHasher<User>>()
|
||||
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
|
||||
.Returns(PasswordVerificationResult.Success);
|
||||
|
||||
var (grant, passwordRequiredError, passwordInvalidError)
|
||||
= sutProvider.Sut.SendCanBeAccessed(null, "TEST");
|
||||
|
||||
Assert.False(grant);
|
||||
Assert.False(passwordRequiredError);
|
||||
Assert.False(passwordInvalidError);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public void SendCanBeAccessed_NullPassword_PasswordRequiredErrorReturnsTrue(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var now = DateTime.UtcNow;
|
||||
send.MaxAccessCount = null;
|
||||
send.AccessCount = 5;
|
||||
send.ExpirationDate = now.AddYears(1);
|
||||
send.DeletionDate = now.AddYears(1);
|
||||
send.Disabled = false;
|
||||
send.Password = "HASH";
|
||||
|
||||
sutProvider.GetDependency<IPasswordHasher<User>>()
|
||||
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
|
||||
.Returns(PasswordVerificationResult.Success);
|
||||
|
||||
var (grant, passwordRequiredError, passwordInvalidError)
|
||||
= sutProvider.Sut.SendCanBeAccessed(send, null);
|
||||
|
||||
Assert.False(grant);
|
||||
Assert.True(passwordRequiredError);
|
||||
Assert.False(passwordInvalidError);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public void SendCanBeAccessed_RehashNeeded_RehashesPassword(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var now = DateTime.UtcNow;
|
||||
send.MaxAccessCount = null;
|
||||
send.AccessCount = 5;
|
||||
send.ExpirationDate = now.AddYears(1);
|
||||
send.DeletionDate = now.AddYears(1);
|
||||
send.Disabled = false;
|
||||
send.Password = "TEST";
|
||||
|
||||
sutProvider.GetDependency<IPasswordHasher<User>>()
|
||||
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
|
||||
.Returns(PasswordVerificationResult.SuccessRehashNeeded);
|
||||
|
||||
var (grant, passwordRequiredError, passwordInvalidError)
|
||||
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
|
||||
|
||||
sutProvider.GetDependency<IPasswordHasher<User>>()
|
||||
.Received(1)
|
||||
.HashPassword(Arg.Any<User>(), "TEST");
|
||||
|
||||
Assert.True(grant);
|
||||
Assert.False(passwordRequiredError);
|
||||
Assert.False(passwordInvalidError);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue(SutProvider<SendService> sutProvider,
|
||||
Send send)
|
||||
{
|
||||
var now = DateTime.UtcNow;
|
||||
send.MaxAccessCount = null;
|
||||
send.AccessCount = 5;
|
||||
send.ExpirationDate = now.AddYears(1);
|
||||
send.DeletionDate = now.AddYears(1);
|
||||
send.Disabled = false;
|
||||
send.Password = "TEST";
|
||||
|
||||
sutProvider.GetDependency<IPasswordHasher<User>>()
|
||||
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
|
||||
.Returns(PasswordVerificationResult.Failed);
|
||||
|
||||
var (grant, passwordRequiredError, passwordInvalidError)
|
||||
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
|
||||
|
||||
Assert.False(grant);
|
||||
Assert.False(passwordRequiredError);
|
||||
Assert.True(passwordInvalidError);
|
||||
}
|
||||
}
|
@ -57,8 +57,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
var localFactory = new IdentityApplicationFactory();
|
||||
var user = await localFactory.RegisterNewIdentityFactoryUserAsync(requestModel);
|
||||
|
||||
var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash,
|
||||
context => context.SetAuthEmail(user.Email));
|
||||
var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash);
|
||||
|
||||
using var body = await AssertDefaultTokenBodyAsync(context);
|
||||
var root = body.RootElement;
|
||||
@ -72,71 +71,6 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
AssertUserDecryptionOptions(root);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData, RegisterFinishRequestModelCustomize]
|
||||
public async Task TokenEndpoint_GrantTypePassword_NoAuthEmailHeader_Fails(
|
||||
RegisterFinishRequestModel requestModel)
|
||||
{
|
||||
requestModel.Email = "test+noauthemailheader@email.com";
|
||||
|
||||
var localFactory = new IdentityApplicationFactory();
|
||||
var user = await localFactory.RegisterNewIdentityFactoryUserAsync(requestModel);
|
||||
|
||||
var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash, null);
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
|
||||
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
var root = body.RootElement;
|
||||
|
||||
var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString();
|
||||
Assert.Equal("invalid_grant", error);
|
||||
AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData, RegisterFinishRequestModelCustomize]
|
||||
public async Task TokenEndpoint_GrantTypePassword_InvalidBase64AuthEmailHeader_Fails(
|
||||
RegisterFinishRequestModel requestModel)
|
||||
{
|
||||
requestModel.Email = "test+badauthheader@email.com";
|
||||
|
||||
var localFactory = new IdentityApplicationFactory();
|
||||
var user = await localFactory.RegisterNewIdentityFactoryUserAsync(requestModel);
|
||||
|
||||
var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash,
|
||||
context => context.Request.Headers.Append("Auth-Email", "bad_value"));
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
|
||||
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
var root = body.RootElement;
|
||||
|
||||
var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString();
|
||||
Assert.Equal("invalid_grant", error);
|
||||
AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData, RegisterFinishRequestModelCustomize]
|
||||
public async Task TokenEndpoint_GrantTypePassword_WrongAuthEmailHeader_Fails(
|
||||
RegisterFinishRequestModel requestModel)
|
||||
{
|
||||
requestModel.Email = "test+badauthheader@email.com";
|
||||
|
||||
var localFactory = new IdentityApplicationFactory();
|
||||
var user = await localFactory.RegisterNewIdentityFactoryUserAsync(requestModel);
|
||||
|
||||
var context = await PostLoginAsync(localFactory.Server, user, requestModel.MasterPasswordHash,
|
||||
context => context.SetAuthEmail("bad_value"));
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
|
||||
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
var root = body.RootElement;
|
||||
|
||||
var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString();
|
||||
Assert.Equal("invalid_grant", error);
|
||||
AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String);
|
||||
}
|
||||
|
||||
[Theory, RegisterFinishRequestModelCustomize]
|
||||
[BitAutoData(OrganizationUserType.Owner)]
|
||||
[BitAutoData(OrganizationUserType.Admin)]
|
||||
@ -157,8 +91,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
await CreateOrganizationWithSsoPolicyAsync(localFactory,
|
||||
organizationId, user.Email, organizationUserType, ssoPolicyEnabled: false);
|
||||
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash,
|
||||
context => context.SetAuthEmail(user.Email));
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
}
|
||||
@ -184,8 +117,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
await CreateOrganizationWithSsoPolicyAsync(
|
||||
localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: false);
|
||||
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash,
|
||||
context => context.SetAuthEmail(user.Email));
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
}
|
||||
@ -209,8 +141,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
|
||||
await CreateOrganizationWithSsoPolicyAsync(localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: true);
|
||||
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash,
|
||||
context => context.SetAuthEmail(user.Email));
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
await AssertRequiredSsoAuthenticationResponseAsync(context);
|
||||
@ -234,8 +165,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
|
||||
await CreateOrganizationWithSsoPolicyAsync(localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: true);
|
||||
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash,
|
||||
context => context.SetAuthEmail(user.Email));
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
}
|
||||
@ -258,8 +188,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
|
||||
await CreateOrganizationWithSsoPolicyAsync(localFactory, organizationId, user.Email, organizationUserType, ssoPolicyEnabled: true);
|
||||
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash,
|
||||
context => context.SetAuthEmail(user.Email));
|
||||
var context = await PostLoginAsync(server, user, requestModel.MasterPasswordHash);
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
await AssertRequiredSsoAuthenticationResponseAsync(context);
|
||||
@ -342,7 +271,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
{ "grant_type", "password" },
|
||||
{ "username", model.Email },
|
||||
{ "password", model.MasterPasswordHash },
|
||||
}), context => context.SetAuthEmail(model.Email));
|
||||
}));
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
|
||||
@ -554,12 +483,12 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
{ "grant_type", "password" },
|
||||
{ "username", user.Email},
|
||||
{ "password", "master_password_hash" },
|
||||
}), context => context.SetAuthEmail(user.Email).SetIp("1.1.1.2"));
|
||||
}), context => context.SetIp("1.1.1.2"));
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<HttpContext> PostLoginAsync(
|
||||
TestServer server, User user, string MasterPasswordHash, Action<HttpContext> extraConfiguration)
|
||||
TestServer server, User user, string MasterPasswordHash)
|
||||
{
|
||||
return await server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary<string, string>
|
||||
{
|
||||
@ -571,7 +500,7 @@ public class IdentityServerTests : IClassFixture<IdentityApplicationFactory>
|
||||
{ "grant_type", "password" },
|
||||
{ "username", user.Email },
|
||||
{ "password", MasterPasswordHash },
|
||||
}), extraConfiguration);
|
||||
}));
|
||||
}
|
||||
|
||||
private async Task CreateOrganizationWithSsoPolicyAsync(
|
||||
|
@ -143,7 +143,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
|
||||
{ "grant_type", "password" },
|
||||
{ "username", _testEmail },
|
||||
{ "password", _testPassword },
|
||||
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail)));
|
||||
}));
|
||||
|
||||
// Assert
|
||||
using var responseBody = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
@ -263,7 +263,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
|
||||
{ "code", "test_code" },
|
||||
{ "code_verifier", challenge },
|
||||
{ "redirect_uri", "https://localhost:8080/sso-connector.html" }
|
||||
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail)));
|
||||
}));
|
||||
|
||||
// Assert
|
||||
using var responseBody = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
@ -307,7 +307,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
|
||||
{ "code", "test_code" },
|
||||
{ "code_verifier", challenge },
|
||||
{ "redirect_uri", "https://localhost:8080/sso-connector.html" }
|
||||
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail)));
|
||||
}));
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, failedTokenContext.Response.StatusCode);
|
||||
Assert.NotNull(emailToken);
|
||||
@ -326,7 +326,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
|
||||
{ "code", "test_code" },
|
||||
{ "code_verifier", challenge },
|
||||
{ "redirect_uri", "https://localhost:8080/sso-connector.html" }
|
||||
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail)));
|
||||
}));
|
||||
|
||||
|
||||
// Assert
|
||||
@ -363,7 +363,7 @@ public class IdentityServerTwoFactorTests : IClassFixture<IdentityApplicationFac
|
||||
{ "code", "test_code" },
|
||||
{ "code_verifier", challenge },
|
||||
{ "redirect_uri", "https://localhost:8080/sso-connector.html" }
|
||||
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(_testEmail)));
|
||||
}));
|
||||
|
||||
// Assert
|
||||
using var responseBody = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
|
@ -29,8 +29,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
|
||||
|
||||
// Act
|
||||
var context = await localFactory.Server.PostAsync("/connect/token",
|
||||
GetFormUrlEncodedContent(),
|
||||
context => context.SetAuthEmail(DefaultUsername));
|
||||
GetFormUrlEncodedContent());
|
||||
|
||||
// Assert
|
||||
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
@ -40,27 +39,6 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
|
||||
Assert.NotNull(token);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ValidateAsync_AuthEmailHeaderInvalid_InvalidGrantResponse()
|
||||
{
|
||||
// Arrange
|
||||
var localFactory = new IdentityApplicationFactory();
|
||||
await EnsureUserCreatedAsync(localFactory);
|
||||
|
||||
// Act
|
||||
var context = await localFactory.Server.PostAsync(
|
||||
"/connect/token",
|
||||
GetFormUrlEncodedContent()
|
||||
);
|
||||
|
||||
// Assert
|
||||
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
var root = body.RootElement;
|
||||
|
||||
var error = AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String).GetString();
|
||||
Assert.Equal("Auth-Email header invalid.", error);
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
public async Task ValidateAsync_UserNull_Failure(string username)
|
||||
{
|
||||
@ -68,8 +46,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
|
||||
var localFactory = new IdentityApplicationFactory();
|
||||
// Act
|
||||
var context = await localFactory.Server.PostAsync("/connect/token",
|
||||
GetFormUrlEncodedContent(username: username),
|
||||
context => context.SetAuthEmail(username));
|
||||
GetFormUrlEncodedContent(username: username));
|
||||
|
||||
// Assert
|
||||
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
@ -106,8 +83,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
|
||||
|
||||
// Act
|
||||
var context = await localFactory.Server.PostAsync("/connect/token",
|
||||
GetFormUrlEncodedContent(password: badPassword),
|
||||
context => context.SetAuthEmail(DefaultUsername));
|
||||
GetFormUrlEncodedContent(password: badPassword));
|
||||
|
||||
// Assert
|
||||
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
@ -155,7 +131,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
|
||||
{ "username", DefaultUsername },
|
||||
{ "password", DefaultPassword },
|
||||
{ "AuthRequest", authRequest.Id.ToString().ToLowerInvariant() }
|
||||
}), context => context.SetAuthEmail(DefaultUsername));
|
||||
}));
|
||||
|
||||
// Assert
|
||||
var body = await AssertHelper.AssertResponseTypeIs<JsonDocument>(context);
|
||||
@ -197,7 +173,7 @@ public class ResourceOwnerPasswordValidatorTests : IClassFixture<IdentityApplica
|
||||
{ "username", DefaultUsername },
|
||||
{ "password", DefaultPassword },
|
||||
{ "AuthRequest", authRequest.Id.ToString().ToLowerInvariant() }
|
||||
}), context => context.SetAuthEmail(DefaultUsername));
|
||||
}));
|
||||
|
||||
// Assert
|
||||
|
||||
|
@ -6,6 +6,7 @@ using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Identity.IdentityServer;
|
||||
using Bit.Identity.Utilities;
|
||||
using Bit.Test.Common.AutoFixture.Attributes;
|
||||
using NSubstitute;
|
||||
using Xunit;
|
||||
@ -17,6 +18,7 @@ public class UserDecryptionOptionsBuilderTests
|
||||
private readonly ICurrentContext _currentContext;
|
||||
private readonly IDeviceRepository _deviceRepository;
|
||||
private readonly IOrganizationUserRepository _organizationUserRepository;
|
||||
private readonly ILoginApprovingClientTypes _loginApprovingClientTypes;
|
||||
private readonly UserDecryptionOptionsBuilder _builder;
|
||||
|
||||
public UserDecryptionOptionsBuilderTests()
|
||||
@ -24,7 +26,8 @@ public class UserDecryptionOptionsBuilderTests
|
||||
_currentContext = Substitute.For<ICurrentContext>();
|
||||
_deviceRepository = Substitute.For<IDeviceRepository>();
|
||||
_organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
|
||||
_builder = new UserDecryptionOptionsBuilder(_currentContext, _deviceRepository, _organizationUserRepository);
|
||||
_loginApprovingClientTypes = Substitute.For<ILoginApprovingClientTypes>();
|
||||
_builder = new UserDecryptionOptionsBuilder(_currentContext, _deviceRepository, _organizationUserRepository, _loginApprovingClientTypes);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
@ -120,17 +123,17 @@ public class UserDecryptionOptionsBuilderTests
|
||||
[BitAutoData(DeviceType.SafariBrowser)]
|
||||
[BitAutoData(DeviceType.VivaldiBrowser)]
|
||||
[BitAutoData(DeviceType.UnknownBrowser)]
|
||||
// Extension
|
||||
[BitAutoData(DeviceType.ChromeExtension)]
|
||||
[BitAutoData(DeviceType.FirefoxExtension)]
|
||||
[BitAutoData(DeviceType.OperaExtension)]
|
||||
[BitAutoData(DeviceType.EdgeExtension)]
|
||||
[BitAutoData(DeviceType.VivaldiExtension)]
|
||||
[BitAutoData(DeviceType.SafariExtension)]
|
||||
public async Task Build_WhenHasLoginApprovingDevice_ShouldApprovingDeviceTrue(
|
||||
DeviceType deviceType,
|
||||
SsoConfig ssoConfig, SsoConfigurationData configurationData, User user, Device device, Device approvingDevice)
|
||||
{
|
||||
_loginApprovingClientTypes.TypesThatCanApprove.Returns(new List<ClientType>
|
||||
{
|
||||
ClientType.Desktop,
|
||||
ClientType.Mobile,
|
||||
ClientType.Web,
|
||||
});
|
||||
|
||||
configurationData.MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption;
|
||||
ssoConfig.Data = configurationData.Serialize();
|
||||
approvingDevice.Type = deviceType;
|
||||
@ -142,9 +145,65 @@ public class UserDecryptionOptionsBuilderTests
|
||||
}
|
||||
|
||||
[Theory]
|
||||
// Desktop
|
||||
[BitAutoData(DeviceType.LinuxDesktop)]
|
||||
[BitAutoData(DeviceType.MacOsDesktop)]
|
||||
[BitAutoData(DeviceType.WindowsDesktop)]
|
||||
[BitAutoData(DeviceType.UWP)]
|
||||
// Mobile
|
||||
[BitAutoData(DeviceType.Android)]
|
||||
[BitAutoData(DeviceType.iOS)]
|
||||
[BitAutoData(DeviceType.AndroidAmazon)]
|
||||
// Web
|
||||
[BitAutoData(DeviceType.ChromeBrowser)]
|
||||
[BitAutoData(DeviceType.FirefoxBrowser)]
|
||||
[BitAutoData(DeviceType.OperaBrowser)]
|
||||
[BitAutoData(DeviceType.EdgeBrowser)]
|
||||
[BitAutoData(DeviceType.IEBrowser)]
|
||||
[BitAutoData(DeviceType.SafariBrowser)]
|
||||
[BitAutoData(DeviceType.VivaldiBrowser)]
|
||||
[BitAutoData(DeviceType.UnknownBrowser)]
|
||||
// Extension
|
||||
[BitAutoData(DeviceType.ChromeExtension)]
|
||||
[BitAutoData(DeviceType.FirefoxExtension)]
|
||||
[BitAutoData(DeviceType.OperaExtension)]
|
||||
[BitAutoData(DeviceType.EdgeExtension)]
|
||||
[BitAutoData(DeviceType.VivaldiExtension)]
|
||||
[BitAutoData(DeviceType.SafariExtension)]
|
||||
public async Task Build_WhenHasLoginApprovingDeviceFeatureFlag_ShouldApprovingDeviceTrue(
|
||||
DeviceType deviceType,
|
||||
SsoConfig ssoConfig, SsoConfigurationData configurationData, User user, Device device, Device approvingDevice)
|
||||
{
|
||||
_loginApprovingClientTypes.TypesThatCanApprove.Returns(new List<ClientType>
|
||||
{
|
||||
ClientType.Desktop,
|
||||
ClientType.Mobile,
|
||||
ClientType.Web,
|
||||
ClientType.Browser,
|
||||
});
|
||||
|
||||
configurationData.MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption;
|
||||
ssoConfig.Data = configurationData.Serialize();
|
||||
approvingDevice.Type = deviceType;
|
||||
_deviceRepository.GetManyByUserIdAsync(user.Id).Returns(new Device[] { approvingDevice });
|
||||
|
||||
var result = await _builder.ForUser(user).WithSso(ssoConfig).WithDevice(device).BuildAsync();
|
||||
|
||||
Assert.True(result.TrustedDeviceOption?.HasLoginApprovingDevice);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
// CLI
|
||||
[BitAutoData(DeviceType.WindowsCLI)]
|
||||
[BitAutoData(DeviceType.MacOsCLI)]
|
||||
[BitAutoData(DeviceType.LinuxCLI)]
|
||||
// Extension
|
||||
[BitAutoData(DeviceType.ChromeExtension)]
|
||||
[BitAutoData(DeviceType.FirefoxExtension)]
|
||||
[BitAutoData(DeviceType.OperaExtension)]
|
||||
[BitAutoData(DeviceType.EdgeExtension)]
|
||||
[BitAutoData(DeviceType.VivaldiExtension)]
|
||||
[BitAutoData(DeviceType.SafariExtension)]
|
||||
public async Task Build_WhenHasLoginApprovingDevice_ShouldApprovingDeviceFalse(
|
||||
DeviceType deviceType,
|
||||
SsoConfig ssoConfig, SsoConfigurationData configurationData, User user, Device device, Device approvingDevice)
|
||||
|
@ -5,10 +5,8 @@ using Bit.Core.Auth.Models.Api.Request.Accounts;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Utilities;
|
||||
using Bit.Identity;
|
||||
using Bit.Test.Common.Helpers;
|
||||
using HandlebarsDotNet;
|
||||
using LinqToDB;
|
||||
using Microsoft.AspNetCore.Hosting;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
@ -98,7 +96,7 @@ public class IdentityApplicationFactory : WebApplicationFactoryBase<Startup>
|
||||
{ "grant_type", "password" },
|
||||
{ "username", username },
|
||||
{ "password", password },
|
||||
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(username)));
|
||||
}));
|
||||
|
||||
return context;
|
||||
}
|
||||
@ -126,7 +124,7 @@ public class IdentityApplicationFactory : WebApplicationFactoryBase<Startup>
|
||||
{ "TwoFactorToken", twoFactorToken },
|
||||
{ "TwoFactorProvider", twoFactorProviderType },
|
||||
{ "TwoFactorRemember", "1" },
|
||||
}), context => context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(username)));
|
||||
}));
|
||||
|
||||
return context;
|
||||
}
|
||||
|
@ -1,5 +1,4 @@
|
||||
using System.Net;
|
||||
using Bit.Core.Utilities;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.TestHost;
|
||||
using Microsoft.Extensions.Primitives;
|
||||
@ -62,12 +61,6 @@ public static class WebApplicationFactoryExtensions
|
||||
Action<HttpContext> extraConfiguration = null)
|
||||
=> SendAsync(server, HttpMethod.Delete, requestUri, content: content, extraConfiguration);
|
||||
|
||||
public static HttpContext SetAuthEmail(this HttpContext context, string username)
|
||||
{
|
||||
context.Request.Headers.Append("Auth-Email", CoreHelpers.Base64UrlEncodeString(username));
|
||||
return context;
|
||||
}
|
||||
|
||||
public static HttpContext SetIp(this HttpContext context, string ip)
|
||||
{
|
||||
context.Connection.RemoteIpAddress = IPAddress.Parse(ip);
|
||||
|
Loading…
x
Reference in New Issue
Block a user