1
0
mirror of https://github.com/bitwarden/server.git synced 2025-05-24 04:51:03 -05:00

Merge branch 'main' into pm-20621-Update-error-message-when-lowering-seat-count

This commit is contained in:
cyprain-okeke 2025-05-20 12:23:46 +01:00 committed by GitHub
commit 0ad34ccb23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
60 changed files with 4595 additions and 2031 deletions

View File

@ -3,7 +3,7 @@
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<Version>2025.5.0</Version>
<Version>2025.5.1</Version>
<RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace>
<ImplicitUsings>enable</ImplicitUsings>

View File

@ -8,13 +8,10 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.DependencyInjection;
using Stripe;
namespace Bit.Commercial.Core.AdminConsole.Providers;
@ -24,7 +21,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
private readonly IEventService _eventService;
private readonly IMailService _mailService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationService _organizationService;
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IStripeAdapter _stripeAdapter;
private readonly IFeatureService _featureService;
@ -32,26 +28,22 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
private readonly ISubscriberService _subscriberService;
private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery;
private readonly IPricingClient _pricingClient;
private readonly IAutomaticTaxStrategy _automaticTaxStrategy;
public RemoveOrganizationFromProviderCommand(
IEventService eventService,
IMailService mailService,
IOrganizationRepository organizationRepository,
IOrganizationService organizationService,
IProviderOrganizationRepository providerOrganizationRepository,
IStripeAdapter stripeAdapter,
IFeatureService featureService,
IProviderBillingService providerBillingService,
ISubscriberService subscriberService,
IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery,
IPricingClient pricingClient,
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
IPricingClient pricingClient)
{
_eventService = eventService;
_mailService = mailService;
_organizationRepository = organizationRepository;
_organizationService = organizationService;
_providerOrganizationRepository = providerOrganizationRepository;
_stripeAdapter = stripeAdapter;
_featureService = featureService;
@ -59,7 +51,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
_subscriberService = subscriberService;
_hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery;
_pricingClient = pricingClient;
_automaticTaxStrategy = automaticTaxStrategy;
}
public async Task RemoveOrganizationFromProvider(
@ -77,7 +68,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync(
providerOrganization.OrganizationId,
Array.Empty<Guid>(),
[],
includeProvider: false))
{
throw new BadRequestException("Organization must have at least one confirmed owner.");
@ -102,7 +93,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
/// <summary>
/// When a client organization is unlinked from a provider, we have to check if they're Stripe-enabled
/// and, if they are, we remove their MSP discount and set their Subscription to `send_invoice`. This is because
/// the provider's payment method will be removed from their Stripe customer causing ensuing charges to fail. Lastly,
/// the provider's payment method will be removed from their Stripe customer, causing ensuing charges to fail. Lastly,
/// we email the organization owners letting them know they need to add a new payment method.
/// </summary>
private async Task ResetOrganizationBillingAsync(
@ -142,15 +133,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }]
};
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
var setNonUSBusinessUseToReverseCharge = _featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{
_automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
}
else
else if (customer.HasRecognizedTaxLocation())
{
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
Enabled = customer.Address.Country == "US" ||
customer.TaxIds.Any()
};
}
@ -187,7 +181,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
await _mailService.SendProviderUpdatePaymentMethod(
organization.Id,
organization.Name,
provider.Name,
provider.Name!,
organizationOwnerEmails);
}
}

View File

@ -18,7 +18,6 @@ using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
@ -27,7 +26,6 @@ using Bit.Core.Services;
using Bit.Core.Settings;
using Braintree;
using CsvHelper;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Stripe;
@ -52,8 +50,7 @@ public class ProviderBillingService(
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
ITaxService taxService,
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
ITaxService taxService)
: IProviderBillingService
{
public async Task AddExistingOrganization(
@ -128,7 +125,7 @@ public class ProviderBillingService(
/*
* We have to scale the provider's seats before the ProviderOrganization
* row is inserted so the added organization's seats don't get double counted.
* row is inserted so the added organization's seats don't get double-counted.
*/
await ScaleSeats(provider, organization.PlanType, organization.Seats!.Value);
@ -236,7 +233,7 @@ public class ProviderBillingService(
var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions
{
Expand = ["tax_ids"]
Expand = ["tax", "tax_ids"]
});
var providerTaxId = providerCustomer.TaxIds.FirstOrDefault();
@ -284,6 +281,13 @@ public class ProviderBillingService(
]
};
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && providerCustomer.Address is not { Country: "US" })
{
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions);
organization.GatewayCustomerId = customer.Id;
@ -520,6 +524,13 @@ public class ProviderBillingService(
}
};
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && taxInfo.BillingAddressCountry != "US")
{
options.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber))
{
var taxIdType = taxService.GetStripeTaxCode(
@ -531,6 +542,7 @@ public class ProviderBillingService(
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInfo.BillingAddressCountry,
taxInfo.TaxIdNumber);
throw new BadRequestException("billingTaxIdTypeInferenceError");
}
@ -718,14 +730,21 @@ public class ProviderBillingService(
TrialPeriodDays = trialPeriodDays
};
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
}
else if (customer.HasRecognizedTaxLocation())
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customer.Address.Country == "US" ||
customer.TaxIds.Any()
};
}
try
{

View File

@ -1,4 +1,5 @@
using Bit.Commercial.Core.AdminConsole.Providers;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
@ -8,7 +9,6 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
@ -224,31 +224,115 @@ public class RemoveOrganizationFromProviderCommandTests
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
options.Description == string.Empty &&
options.Email == organization.BillingEmail &&
options.Expand[0] == "tax" &&
options.Expand[1] == "tax_ids")).Returns(new Customer
{
Id = "customer_id",
Address = new Address
{
Country = "US"
}
});
stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription
{
Id = "subscription_id"
});
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == organization.GatewayCustomerId &&
options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice &&
options.DaysUntilDue == 30 &&
options.AutomaticTax.Enabled == true &&
options.Metadata["organizationId"] == organization.Id.ToString() &&
options.OffSession == true &&
options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId &&
options.Items.First().Quantity == organization.Seats)
, Arg.Any<Customer>()))
.Do(x =>
options.Items.First().Quantity == organization.Seats));
await sutProvider.GetDependency<IProviderBillingService>().Received(1)
.ScaleSeats(provider, organization.PlanType, -organization.Seats ?? 0);
await organizationRepository.Received(1).ReplaceAsync(Arg.Is<Organization>(
org =>
org.BillingEmail == "a@example.com" &&
org.GatewaySubscriptionId == "subscription_id" &&
org.Status == OrganizationStatusType.Created));
await sutProvider.GetDependency<IProviderOrganizationRepository>().Received(1)
.DeleteAsync(providerOrganization);
await sutProvider.GetDependency<IEventService>().Received(1)
.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed);
await sutProvider.GetDependency<IMailService>().Received(1)
.SendProviderUpdatePaymentMethod(
organization.Id,
organization.Name,
provider.Name,
Arg.Is<IEnumerable<string>>(emails => emails.FirstOrDefault() == "a@example.com"));
}
[Theory, BitAutoData]
public async Task RemoveOrganizationFromProvider_OrganizationStripeEnabled_ConsolidatedBilling_ReverseCharge_MakesCorrectInvocations(
Provider provider,
ProviderOrganization providerOrganization,
Organization organization,
SutProvider<RemoveOrganizationFromProviderCommand> sutProvider)
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
provider.Status = ProviderStatusType.Billable;
providerOrganization.ProviderId = provider.Id;
organization.Status = OrganizationStatusType.Managed;
organization.PlanType = PlanType.TeamsMonthly;
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan);
sutProvider.GetDependency<IHasConfirmedOwnersExceptQuery>().HasConfirmedOwnersExceptAsync(
providerOrganization.OrganizationId,
[],
includeProvider: false)
.Returns(true);
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns([
"a@example.com",
"b@example.com"
]);
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(options =>
options.Description == string.Empty &&
options.Email == organization.BillingEmail &&
options.Expand[0] == "tax" &&
options.Expand[1] == "tax_ids")).Returns(new Customer
{
Enabled = true
};
Id = "customer_id",
Address = new Address
{
Country = "US"
}
});
stripeAdapter.SubscriptionCreateAsync(Arg.Any<SubscriptionCreateOptions>()).Returns(new Subscription
{
Id = "subscription_id"
});
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>

View File

@ -262,7 +262,7 @@ public class ProviderBillingServiceTests
};
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>(
options => options.Expand.FirstOrDefault() == "tax_ids"))
options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids")))
.Returns(providerCustomer);
sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri
@ -312,6 +312,91 @@ public class ProviderBillingServiceTests
org => org.GatewayCustomerId == "customer_id"));
}
[Theory, BitAutoData]
public async Task CreateCustomer_ForClientOrg_ReverseCharge_Succeeds(
Provider provider,
Organization organization,
SutProvider<ProviderBillingService> sutProvider)
{
organization.GatewayCustomerId = null;
organization.Name = "Name";
organization.BusinessName = "BusinessName";
var providerCustomer = new Customer
{
Address = new Address
{
Country = "CA",
PostalCode = "12345",
Line1 = "123 Main St.",
Line2 = "Unit 4",
City = "Fake Town",
State = "Fake State"
},
TaxIds = new StripeList<TaxId>
{
Data =
[
new TaxId { Type = "TYPE", Value = "VALUE" }
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider, Arg.Is<CustomerGetOptions>(
options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids")))
.Returns(providerCustomer);
sutProvider.GetDependency<IGlobalSettings>().BaseServiceUri
.Returns(new Bit.Core.Settings.GlobalSettings.BaseServiceUriSettings(new Bit.Core.Settings.GlobalSettings())
{
CloudRegion = "US"
});
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
sutProvider.GetDependency<IStripeAdapter>().CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
options =>
options.Address.Country == providerCustomer.Address.Country &&
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
options.Address.Line1 == providerCustomer.Address.Line1 &&
options.Address.Line2 == providerCustomer.Address.Line2 &&
options.Address.City == providerCustomer.Address.City &&
options.Address.State == providerCustomer.Address.State &&
options.Name == organization.DisplayName() &&
options.Description == $"{provider.Name} Client Organization" &&
options.Email == provider.BillingEmail &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
options.Metadata["region"] == "US" &&
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value &&
options.TaxExempt == StripeConstants.TaxExempt.Reverse))
.Returns(new Customer { Id = "customer_id" });
await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization);
await sutProvider.GetDependency<IStripeAdapter>().Received(1).CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(
options =>
options.Address.Country == providerCustomer.Address.Country &&
options.Address.PostalCode == providerCustomer.Address.PostalCode &&
options.Address.Line1 == providerCustomer.Address.Line1 &&
options.Address.Line2 == providerCustomer.Address.Line2 &&
options.Address.City == providerCustomer.Address.City &&
options.Address.State == providerCustomer.Address.State &&
options.Name == organization.DisplayName() &&
options.Description == $"{provider.Name} Client Organization" &&
options.Email == provider.BillingEmail &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" &&
options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" &&
options.Metadata["region"] == "US" &&
options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type &&
options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value));
await sutProvider.GetDependency<IOrganizationRepository>().Received(1).ReplaceAsync(Arg.Is<Organization>(
org => org.GatewayCustomerId == "customer_id"));
}
#endregion
#region GenerateClientInvoiceReport
@ -1182,6 +1267,62 @@ public class ProviderBillingServiceTests
Assert.Equivalent(expected, actual);
}
[Theory, BitAutoData]
public async Task SetupCustomer_WithCard_ReverseCharge_Success(
SutProvider<ProviderBillingService> sutProvider,
Provider provider,
TaxInfo taxInfo)
{
provider.Name = "MSP";
sutProvider.GetDependency<ITaxService>()
.GetStripeTaxCode(Arg.Is<string>(
p => p == taxInfo.BillingAddressCountry),
Arg.Is<string>(p => p == taxInfo.TaxIdNumber))
.Returns(taxInfo.TaxIdType);
taxInfo.BillingAddressCountry = "AD";
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var expected = new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
};
var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token");
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
stripeAdapter.CustomerCreateAsync(Arg.Is<CustomerCreateOptions>(o =>
o.Address.Country == taxInfo.BillingAddressCountry &&
o.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
o.Address.Line1 == taxInfo.BillingAddressLine1 &&
o.Address.Line2 == taxInfo.BillingAddressLine2 &&
o.Address.City == taxInfo.BillingAddressCity &&
o.Address.State == taxInfo.BillingAddressState &&
o.Description == WebUtility.HtmlDecode(provider.BusinessName) &&
o.Email == provider.BillingEmail &&
o.PaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" &&
o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" &&
o.Metadata["region"] == "" &&
o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType &&
o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber &&
o.TaxExempt == StripeConstants.TaxExempt.Reverse))
.Returns(expected);
var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource);
Assert.Equivalent(expected, actual);
}
[Theory, BitAutoData]
public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid(
SutProvider<ProviderBillingService> sutProvider,
@ -1307,7 +1448,7 @@ public class ProviderBillingServiceTests
.Returns(new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
Address = new Address { Country = "US" }
});
var providerPlans = new List<ProviderPlan>
@ -1359,7 +1500,7 @@ public class ProviderBillingServiceTests
var customer = new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
Address = new Address { Country = "US" }
};
sutProvider.GetDependency<ISubscriberService>()
.GetCustomerOrThrow(
@ -1399,19 +1540,6 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub =>
sub.AutomaticTax.Enabled == true &&
@ -1443,11 +1571,11 @@ public class ProviderBillingServiceTests
var customer = new Customer
{
Id = "customer_id",
Address = new Address { Country = "US" },
InvoiceSettings = new CustomerInvoiceSettings
{
DefaultPaymentMethodId = "pm_123"
},
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}
};
sutProvider.GetDependency<ISubscriberService>()
@ -1488,19 +1616,6 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
@ -1536,9 +1651,9 @@ public class ProviderBillingServiceTests
var customer = new Customer
{
Id = "customer_id",
Address = new Address { Country = "US" },
InvoiceSettings = new CustomerInvoiceSettings(),
Metadata = new Dictionary<string, string>(),
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
Metadata = new Dictionary<string, string>()
};
sutProvider.GetDependency<ISubscriberService>()
@ -1579,19 +1694,6 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
@ -1646,12 +1748,15 @@ public class ProviderBillingServiceTests
var customer = new Customer
{
Id = "customer_id",
Address = new Address
{
Country = "US"
},
InvoiceSettings = new CustomerInvoiceSettings(),
Metadata = new Dictionary<string, string>
{
["btCustomerId"] = "braintree_customer_id"
},
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
}
};
sutProvider.GetDependency<ISubscriberService>()
@ -1692,22 +1797,92 @@ public class ProviderBillingServiceTests
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub =>
sub.AutomaticTax.Enabled == true &&
sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically &&
sub.Customer == "customer_id" &&
sub.DaysUntilDue == null &&
sub.Items.Count == 2 &&
sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams &&
sub.Items.ElementAt(0).Quantity == 100 &&
sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise &&
sub.Items.ElementAt(1).Quantity == 100 &&
sub.Metadata["providerId"] == provider.Id.ToString() &&
sub.OffSession == true &&
sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
sub.TrialPeriodDays == 14)).Returns(expected);
var actual = await sutProvider.Sut.SetupSubscription(provider);
Assert.Equivalent(expected, actual);
}
[Theory, BitAutoData]
public async Task SetupSubscription_ReverseCharge_Succeeds(
SutProvider<ProviderBillingService> sutProvider,
Provider provider)
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
provider.Type = ProviderType.Msp;
provider.GatewaySubscriptionId = null;
var customer = new Customer
{
Enabled = true
Id = "customer_id",
Address = new Address { Country = "CA" },
InvoiceSettings = new CustomerInvoiceSettings
{
DefaultPaymentMethodId = "pm_123"
}
};
});
sutProvider.GetDependency<ISubscriberService>()
.GetCustomerOrThrow(
provider,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer);
var providerPlans = new List<ProviderPlan>
{
new()
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.TeamsMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
AllocatedSeats = 0
},
new()
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.EnterpriseMonthly,
SeatMinimum = 100,
PurchasedSeats = 0,
AllocatedSeats = 0
}
};
foreach (var plan in providerPlans)
{
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(plan.PlanType)
.Returns(StaticStore.GetPlan(plan.PlanType));
}
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id)
.Returns(providerPlans);
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub =>
sub.AutomaticTax.Enabled == true &&

View File

@ -2,7 +2,7 @@
using Bit.Core;
using Bit.Core.Jobs;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.Services;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Quartz;
namespace Bit.Admin.Tools.Jobs;
@ -32,10 +32,10 @@ public class DeleteSendsJob : BaseJob
}
using (var scope = _serviceProvider.CreateScope())
{
var sendService = scope.ServiceProvider.GetRequiredService<ISendService>();
var nonAnonymousSendCommand = scope.ServiceProvider.GetRequiredService<INonAnonymousSendCommand>();
foreach (var send in sends)
{
await sendService.DeleteSendAsync(send);
await nonAnonymousSendCommand.DeleteSendAsync(send);
}
}
}

View File

@ -292,15 +292,17 @@ public class OrganizationBillingController(
sale.Organization.PlanType = plan.Type;
sale.Organization.Plan = plan.Name;
sale.SubscriptionSetup.SkipTrial = true;
await organizationBillingService.Finalize(sale);
if (organizationSignup.PaymentMethodType == null || string.IsNullOrEmpty(organizationSignup.PaymentToken))
{
return Error.BadRequest("A payment method is required to restart the subscription.");
}
var org = await organizationRepository.GetByIdAsync(organizationId);
Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine.");
if (organizationSignup.PaymentMethodType != null)
{
var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken);
var taxInformation = TaxInformation.From(organizationSignup.TaxInfo);
await organizationBillingService.UpdatePaymentMethod(org, paymentSource, taxInformation);
}
await organizationBillingService.Finalize(sale);
return TypedResults.Ok();
}

View File

@ -109,28 +109,6 @@ public class OrganizationsController(
return license;
}
[HttpPost("{id:guid}/payment")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task PostPayment(Guid id, [FromBody] PaymentRequestModel model)
{
if (!await currentContext.EditPaymentMethods(id))
{
throw new NotFoundException();
}
await organizationService.ReplacePaymentMethodAsync(id, model.PaymentToken,
model.PaymentMethodType.Value, new TaxInfo
{
BillingAddressLine1 = model.Line1,
BillingAddressLine2 = model.Line2,
BillingAddressState = model.State,
BillingAddressCity = model.City,
BillingAddressPostalCode = model.PostalCode,
BillingAddressCountry = model.Country,
TaxIdNumber = model.TaxId,
});
}
[HttpPost("{id:guid}/upgrade")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<PaymentResponseModel> PostUpgrade(Guid id, [FromBody] OrganizationUpgradeRequestModel model)

View File

@ -12,17 +12,17 @@ namespace Bit.Api.KeyManagement.Validators;
/// </summary>
public class SendRotationValidator : IRotationValidator<IEnumerable<SendWithIdRequestModel>, IReadOnlyList<Send>>
{
private readonly ISendService _sendService;
private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly ISendRepository _sendRepository;
/// <summary>
/// Instantiates a new <see cref="SendRotationValidator"/>
/// </summary>
/// <param name="sendService">Enables conversion of <see cref="SendWithIdRequestModel"/> to <see cref="Send"/></param>
/// <param name="sendAuthorizationService">Enables conversion of <see cref="SendWithIdRequestModel"/> to <see cref="Send"/></param>
/// <param name="sendRepository">Retrieves all user <see cref="Send"/>s</param>
public SendRotationValidator(ISendService sendService, ISendRepository sendRepository)
public SendRotationValidator(ISendAuthorizationService sendAuthorizationService, ISendRepository sendRepository)
{
_sendService = sendService;
_sendAuthorizationService = sendAuthorizationService;
_sendRepository = sendRepository;
}
@ -44,7 +44,7 @@ public class SendRotationValidator : IRotationValidator<IEnumerable<SendWithIdRe
throw new BadRequestException("All existing sends must be included in the rotation.");
}
result.Add(send.ToSend(existing, _sendService));
result.Add(send.ToSend(existing, _sendAuthorizationService));
}
return result;

View File

@ -35,6 +35,7 @@ using Bit.Core.Services;
using Bit.Core.Tools.ImportFeatures;
using Bit.Core.Tools.ReportFeatures;
using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Tools.SendFeatures;
#if !OSS
using Bit.Commercial.Core.SecretsManager;
@ -186,6 +187,7 @@ public class Startup
services.AddPhishingDomainServices(globalSettings);
services.AddBillingQueries();
services.AddSendServices();
// Authorization Handlers
services.AddAuthorizationHandlers();

View File

@ -12,6 +12,8 @@ using Bit.Core.Settings;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
@ -25,8 +27,10 @@ public class SendsController : Controller
{
private readonly ISendRepository _sendRepository;
private readonly IUserService _userService;
private readonly ISendService _sendService;
private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IAnonymousSendCommand _anonymousSendCommand;
private readonly INonAnonymousSendCommand _nonAnonymousSendCommand;
private readonly ILogger<SendsController> _logger;
private readonly GlobalSettings _globalSettings;
private readonly ICurrentContext _currentContext;
@ -34,7 +38,9 @@ public class SendsController : Controller
public SendsController(
ISendRepository sendRepository,
IUserService userService,
ISendService sendService,
ISendAuthorizationService sendAuthorizationService,
IAnonymousSendCommand anonymousSendCommand,
INonAnonymousSendCommand nonAnonymousSendCommand,
ISendFileStorageService sendFileStorageService,
ILogger<SendsController> logger,
GlobalSettings globalSettings,
@ -42,13 +48,16 @@ public class SendsController : Controller
{
_sendRepository = sendRepository;
_userService = userService;
_sendService = sendService;
_sendAuthorizationService = sendAuthorizationService;
_anonymousSendCommand = anonymousSendCommand;
_nonAnonymousSendCommand = nonAnonymousSendCommand;
_sendFileStorageService = sendFileStorageService;
_logger = logger;
_globalSettings = globalSettings;
_currentContext = currentContext;
}
#region Anonymous endpoints
[AllowAnonymous]
[HttpPost("access/{id}")]
public async Task<IActionResult> Access(string id, [FromBody] SendAccessRequestModel model)
@ -61,18 +70,19 @@ public class SendsController : Controller
//}
var guid = new Guid(CoreHelpers.Base64UrlDecode(id));
var (send, passwordRequired, passwordInvalid) =
await _sendService.AccessAsync(guid, model.Password);
if (passwordRequired)
var send = await _sendRepository.GetByIdAsync(guid);
SendAccessResult sendAuthResult =
await _sendAuthorizationService.AccessAsync(send, model.Password);
if (sendAuthResult.Equals(SendAccessResult.PasswordRequired))
{
return new UnauthorizedResult();
}
if (passwordInvalid)
if (sendAuthResult.Equals(SendAccessResult.PasswordInvalid))
{
await Task.Delay(2000);
throw new BadRequestException("Invalid password.");
}
if (send == null)
if (sendAuthResult.Equals(SendAccessResult.Denied))
{
throw new NotFoundException();
}
@ -106,19 +116,19 @@ public class SendsController : Controller
throw new BadRequestException("Could not locate send");
}
var (url, passwordRequired, passwordInvalid) = await _sendService.GetSendFileDownloadUrlAsync(send, fileId,
var (url, result) = await _anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId,
model.Password);
if (passwordRequired)
if (result.Equals(SendAccessResult.PasswordRequired))
{
return new UnauthorizedResult();
}
if (passwordInvalid)
if (result.Equals(SendAccessResult.PasswordInvalid))
{
await Task.Delay(2000);
throw new BadRequestException("Invalid password.");
}
if (send == null)
if (result.Equals(SendAccessResult.Denied))
{
throw new NotFoundException();
}
@ -130,6 +140,45 @@ public class SendsController : Controller
});
}
[AllowAnonymous]
[HttpPost("file/validate/azure")]
public async Task<ObjectResult> AzureValidateFile()
{
return await ApiHelpers.HandleAzureEvents(Request, new Dictionary<string, Func<EventGridEvent, Task>>
{
{
"Microsoft.Storage.BlobCreated", async (eventGridEvent) =>
{
try
{
var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1];
var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName);
var send = await _sendRepository.GetByIdAsync(new Guid(sendId));
if (send == null)
{
if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService)
{
await azureSendFileStorageService.DeleteBlobAsync(blobName);
}
return;
}
await _nonAnonymousSendCommand.ConfirmFileSize(send);
}
catch (Exception e)
{
_logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}");
return;
}
}
}
});
}
#endregion
#region Non-anonymous endpoints
[HttpGet("{id}")]
public async Task<SendResponseModel> Get(string id)
{
@ -157,8 +206,8 @@ public class SendsController : Controller
{
model.ValidateCreation();
var userId = _userService.GetProperUserId(User).Value;
var send = model.ToSend(userId, _sendService);
await _sendService.SaveSendAsync(send);
var send = model.ToSend(userId, _sendAuthorizationService);
await _nonAnonymousSendCommand.SaveSendAsync(send);
return new SendResponseModel(send, _globalSettings);
}
@ -175,15 +224,15 @@ public class SendsController : Controller
throw new BadRequestException("Invalid content. File size hint is required.");
}
if (model.FileLength.Value > SendService.MAX_FILE_SIZE)
if (model.FileLength.Value > Constants.FileSize501mb)
{
throw new BadRequestException($"Max file size is {SendService.MAX_FILE_SIZE_READABLE}.");
throw new BadRequestException($"Max file size is {SendFileSettingHelper.MAX_FILE_SIZE_READABLE}.");
}
model.ValidateCreation();
var userId = _userService.GetProperUserId(User).Value;
var (send, data) = model.ToSend(userId, model.File.FileName, _sendService);
var uploadUrl = await _sendService.SaveFileSendAsync(send, data, model.FileLength.Value);
var (send, data) = model.ToSend(userId, model.File.FileName, _sendAuthorizationService);
var uploadUrl = await _nonAnonymousSendCommand.SaveFileSendAsync(send, data, model.FileLength.Value);
return new SendFileUploadDataResponseModel
{
Url = uploadUrl,
@ -230,41 +279,7 @@ public class SendsController : Controller
var send = await _sendRepository.GetByIdAsync(new Guid(id));
await Request.GetFileAsync(async (stream) =>
{
await _sendService.UploadFileToExistingSendAsync(stream, send);
});
}
[AllowAnonymous]
[HttpPost("file/validate/azure")]
public async Task<ObjectResult> AzureValidateFile()
{
return await ApiHelpers.HandleAzureEvents(Request, new Dictionary<string, Func<EventGridEvent, Task>>
{
{
"Microsoft.Storage.BlobCreated", async (eventGridEvent) =>
{
try
{
var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1];
var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName);
var send = await _sendRepository.GetByIdAsync(new Guid(sendId));
if (send == null)
{
if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService)
{
await azureSendFileStorageService.DeleteBlobAsync(blobName);
}
return;
}
await _sendService.ValidateSendFile(send);
}
catch (Exception e)
{
_logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}");
return;
}
}
}
await _nonAnonymousSendCommand.UploadFileToExistingSendAsync(stream, send);
});
}
@ -279,7 +294,7 @@ public class SendsController : Controller
throw new NotFoundException();
}
await _sendService.SaveSendAsync(model.ToSend(send, _sendService));
await _nonAnonymousSendCommand.SaveSendAsync(model.ToSend(send, _sendAuthorizationService));
return new SendResponseModel(send, _globalSettings);
}
@ -294,7 +309,7 @@ public class SendsController : Controller
}
send.Password = null;
await _sendService.SaveSendAsync(send);
await _nonAnonymousSendCommand.SaveSendAsync(send);
return new SendResponseModel(send, _globalSettings);
}
@ -308,6 +323,8 @@ public class SendsController : Controller
throw new NotFoundException();
}
await _sendService.DeleteSendAsync(send);
await _nonAnonymousSendCommand.DeleteSendAsync(send);
}
#endregion
}

View File

@ -36,31 +36,31 @@ public class SendRequestModel
public bool? Disabled { get; set; }
public bool? HideEmail { get; set; }
public Send ToSend(Guid userId, ISendService sendService)
public Send ToSend(Guid userId, ISendAuthorizationService sendAuthorizationService)
{
var send = new Send
{
Type = Type,
UserId = (Guid?)userId
};
ToSend(send, sendService);
ToSend(send, sendAuthorizationService);
return send;
}
public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendService sendService)
public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendAuthorizationService sendAuthorizationService)
{
var send = ToSendBase(new Send
{
Type = Type,
UserId = (Guid?)userId
}, sendService);
}, sendAuthorizationService);
var data = new SendFileData(Name, Notes, fileName);
return (send, data);
}
public Send ToSend(Send existingSend, ISendService sendService)
public Send ToSend(Send existingSend, ISendAuthorizationService sendAuthorizationService)
{
existingSend = ToSendBase(existingSend, sendService);
existingSend = ToSendBase(existingSend, sendAuthorizationService);
switch (existingSend.Type)
{
case SendType.File:
@ -125,7 +125,7 @@ public class SendRequestModel
}
}
private Send ToSendBase(Send existingSend, ISendService sendService)
private Send ToSendBase(Send existingSend, ISendAuthorizationService authorizationService)
{
existingSend.Key = Key;
existingSend.ExpirationDate = ExpirationDate;
@ -133,7 +133,7 @@ public class SendRequestModel
existingSend.MaxAccessCount = MaxAccessCount;
if (!string.IsNullOrWhiteSpace(Password))
{
existingSend.Password = sendService.HashPassword(Password);
existingSend.Password = authorizationService.HashPassword(Password);
}
existingSend.Disabled = Disabled.GetValueOrDefault();
existingSend.HideEmail = HideEmail.GetValueOrDefault();

View File

@ -1,11 +1,11 @@
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
@ -25,8 +25,7 @@ public class UpcomingInvoiceHandler(
IStripeEventService stripeEventService,
IStripeEventUtilityService stripeEventUtilityService,
IUserRepository userRepository,
IValidateSponsorshipCommand validateSponsorshipCommand,
IAutomaticTaxFactory automaticTaxFactory)
IValidateSponsorshipCommand validateSponsorshipCommand)
: IUpcomingInvoiceHandler
{
public async Task HandleAsync(Event parsedEvent)
@ -46,6 +45,8 @@ public class UpcomingInvoiceHandler(
var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata);
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (organizationId.HasValue)
{
var organization = await organizationRepository.GetByIdAsync(organizationId.Value);
@ -55,7 +56,7 @@ public class UpcomingInvoiceHandler(
return;
}
await TryEnableAutomaticTaxAsync(subscription);
await AlignOrganizationTaxConcernsAsync(organization, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge);
var plan = await pricingClient.GetPlanOrThrow(organization.PlanType);
@ -100,7 +101,25 @@ public class UpcomingInvoiceHandler(
return;
}
await TryEnableAutomaticTaxAsync(subscription);
if (!subscription.AutomaticTax.Enabled && subscription.Customer.HasRecognizedTaxLocation())
{
try
{
await stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}",
user.Id,
parsedEvent.Id);
}
}
if (user.Premium)
{
@ -116,7 +135,7 @@ public class UpcomingInvoiceHandler(
return;
}
await TryEnableAutomaticTaxAsync(subscription);
await AlignProviderTaxConcernsAsync(provider, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge);
await SendUpcomingInvoiceEmailsAsync(new List<string> { provider.BillingEmail }, invoice);
}
@ -139,50 +158,123 @@ public class UpcomingInvoiceHandler(
}
}
private async Task TryEnableAutomaticTaxAsync(Subscription subscription)
private async Task AlignOrganizationTaxConcernsAsync(
Organization organization,
Subscription subscription,
string eventId,
bool setNonUSBusinessUseToReverseCharge)
{
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id));
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
var nonUSBusinessUse =
organization.PlanType.GetProductTier() != ProductTierType.Families &&
subscription.Customer.Address.Country != "US";
if (updateOptions == null)
bool setAutomaticTaxToEnabled;
if (setNonUSBusinessUseToReverseCharge)
{
return;
if (nonUSBusinessUse && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
{
try
{
await stripeFacade.UpdateCustomer(subscription.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}",
organization.Id,
eventId);
}
}
await stripeFacade.UpdateSubscription(subscription.Id, updateOptions);
return;
setAutomaticTaxToEnabled = true;
}
if (subscription.AutomaticTax.Enabled ||
!subscription.Customer.HasBillingLocation() ||
await IsNonTaxableNonUSBusinessUseSubscription(subscription))
else
{
return;
setAutomaticTaxToEnabled =
subscription.Customer.HasRecognizedTaxLocation() &&
(subscription.Customer.Address.Country == "US" ||
(nonUSBusinessUse && subscription.Customer.TaxIds.Any()));
}
if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled)
{
try
{
await stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
DefaultTaxRates = [],
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
return;
async Task<bool> IsNonTaxableNonUSBusinessUseSubscription(Subscription localSubscription)
}
catch (Exception exception)
{
var familyPriceIds = (await Task.WhenAll(
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019),
pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)))
.Select(plan => plan.PasswordManager.StripePlanId);
logger.LogError(
exception,
"Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}",
organization.Id,
eventId);
}
}
}
return localSubscription.Customer.Address.Country != "US" &&
localSubscription.Metadata.ContainsKey(StripeConstants.MetadataKeys.OrganizationId) &&
!localSubscription.Items.Select(item => item.Price.Id).Intersect(familyPriceIds).Any() &&
!localSubscription.Customer.TaxIds.Any();
private async Task AlignProviderTaxConcernsAsync(
Provider provider,
Subscription subscription,
string eventId,
bool setNonUSBusinessUseToReverseCharge)
{
bool setAutomaticTaxToEnabled;
if (setNonUSBusinessUseToReverseCharge)
{
if (subscription.Customer.Address.Country != "US" && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse)
{
try
{
await stripeFacade.UpdateCustomer(subscription.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}",
provider.Id,
eventId);
}
}
setAutomaticTaxToEnabled = true;
}
else
{
setAutomaticTaxToEnabled =
subscription.Customer.HasRecognizedTaxLocation() &&
(subscription.Customer.Address.Country == "US" ||
subscription.Customer.TaxIds.Any());
}
if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled)
{
try
{
await stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
catch (Exception exception)
{
logger.LogError(
exception,
"Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}",
provider.Id,
eventId);
}
}
}
}

View File

@ -11,8 +11,6 @@ namespace Bit.Core.Services;
public interface IOrganizationService
{
Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType,
TaxInfo taxInfo);
Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null);
Task ReinstateSubscriptionAsync(Guid organizationId);
Task<string> AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb);

View File

@ -144,27 +144,6 @@ public class OrganizationService : IOrganizationService
_sendOrganizationInvitesCommand = sendOrganizationInvitesCommand;
}
public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken,
PaymentMethodType paymentMethodType, TaxInfo taxInfo)
{
var organization = await GetOrgById(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _paymentService.SaveTaxInfoAsync(organization, taxInfo);
var updated = await _paymentService.UpdatePaymentMethodAsync(
organization,
paymentMethodType,
paymentToken,
taxInfo);
if (updated)
{
await ReplaceAndUpdateCacheAsync(organization);
}
}
public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null)
{
var organization = await GetOrgById(organizationId);

View File

@ -2,9 +2,24 @@
public enum EmergencyAccessStatusType : byte
{
/// <summary>
/// The user has been invited to be an emergency contact.
/// </summary>
Invited = 0,
/// <summary>
/// The invited user, "grantee", has accepted the request to be an emergency contact.
/// </summary>
Accepted = 1,
/// <summary>
/// The inviting user, "grantor", has approved the grantee's acceptance.
/// </summary>
Confirmed = 2,
/// <summary>
/// The grantee has initiated the recovery process.
/// </summary>
RecoveryInitiated = 3,
/// <summary>
/// The grantee has excercised their emergency access.
/// </summary>
RecoveryApproved = 4,
}

View File

@ -3,6 +3,7 @@ using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Core.Vault.Models.Data;
@ -20,6 +21,15 @@ public interface IEmergencyAccessService
Task InitiateAsync(Guid id, User initiatingUser);
Task ApproveAsync(Guid id, User approvingUser);
Task RejectAsync(Guid id, User rejectingUser);
/// <summary>
/// This request is made by the Grantee user to fetch the policies <see cref="Policy"/> for the Grantor User.
/// The Grantor User has to be the owner of the organization. <see cref="OrganizationUserType"/>
/// If the Grantor user has OrganizationUserType.Owner then the policies for the _Grantor_ user
/// are returned.
/// </summary>
/// <param name="id">EmergencyAccess.Id being acted on</param>
/// <param name="requestingUser">User making the request, this is the Grantee</param>
/// <returns>null if the GrantorUser is not an organization owner; A list of policies otherwise.</returns>
Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser);
Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser);
Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key);

View File

@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Models;
using Bit.Core.Auth.Models.Business.Tokenables;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities;
@ -16,7 +15,6 @@ using Bit.Core.Tokens;
using Bit.Core.Vault.Models.Data;
using Bit.Core.Vault.Repositories;
using Bit.Core.Vault.Services;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Auth.Services;
@ -31,8 +29,6 @@ public class EmergencyAccessService : IEmergencyAccessService
private readonly IMailService _mailService;
private readonly IUserService _userService;
private readonly GlobalSettings _globalSettings;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IOrganizationService _organizationService;
private readonly IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> _dataProtectorTokenizer;
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
@ -45,9 +41,7 @@ public class EmergencyAccessService : IEmergencyAccessService
ICipherService cipherService,
IMailService mailService,
IUserService userService,
IPasswordHasher<User> passwordHasher,
GlobalSettings globalSettings,
IOrganizationService organizationService,
IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> dataProtectorTokenizer,
IRemoveOrganizationUserCommand removeOrganizationUserCommand)
{
@ -59,9 +53,7 @@ public class EmergencyAccessService : IEmergencyAccessService
_cipherService = cipherService;
_mailService = mailService;
_userService = userService;
_passwordHasher = passwordHasher;
_globalSettings = globalSettings;
_organizationService = organizationService;
_dataProtectorTokenizer = dataProtectorTokenizer;
_removeOrganizationUserCommand = removeOrganizationUserCommand;
}
@ -126,7 +118,12 @@ public class EmergencyAccessService : IEmergencyAccessService
throw new BadRequestException("Emergency Access not valid.");
}
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email))
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data))
{
throw new BadRequestException("Invalid token.");
}
if (!data.IsValid(emergencyAccessId, user.Email))
{
throw new BadRequestException("Invalid token.");
}
@ -140,6 +137,8 @@ public class EmergencyAccessService : IEmergencyAccessService
throw new BadRequestException("Invitation already accepted.");
}
// TODO PM-21687
// Might not be reachable since the Tokenable.IsValid() does an email comparison
if (string.IsNullOrWhiteSpace(emergencyAccess.Email) ||
!emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{
@ -163,6 +162,8 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
// TODO PM-19438/PM-21687
// Not sure why the GrantorId and the GranteeId are supposed to be the same?
if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId))
{
throw new BadRequestException("Emergency Access not valid.");
@ -171,9 +172,9 @@ public class EmergencyAccessService : IEmergencyAccessService
await _emergencyAccessRepository.DeleteAsync(emergencyAccess);
}
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId)
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAccessId, string key, Guid confirmingUserId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId);
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted ||
emergencyAccess.GrantorId != confirmingUserId)
{
@ -224,7 +225,6 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task InitiateAsync(Guid id, User initiatingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.Confirmed)
{
@ -285,6 +285,9 @@ public class EmergencyAccessService : IEmergencyAccessService
public async Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser)
{
// TODO PM-21687
// Should we look up policies here or just verify the EmergencyAccess is correct
// and handle policy logic else where? Should this be a query/Command?
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
@ -295,7 +298,9 @@ public class EmergencyAccessService : IEmergencyAccessService
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id);
var isOrganizationOwner = grantorOrganizations.Any<OrganizationUser>(organization => organization.Type == OrganizationUserType.Owner);
var isOrganizationOwner = grantorOrganizations
.Any(organization => organization.Type == OrganizationUserType.Owner);
var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null;
return policies;
@ -311,7 +316,8 @@ public class EmergencyAccessService : IEmergencyAccessService
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
// TODO PM-21687
// Redundant check of the EmergencyAccessType -> checked in IsValidRequest() ln 308
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{
throw new BadRequestException("You cannot takeover an account that is using Key Connector.");
@ -336,7 +342,9 @@ public class EmergencyAccessService : IEmergencyAccessService
grantor.LastPasswordChangeDate = grantor.RevisionDate;
grantor.Key = key;
// Disable TwoFactor providers since they will otherwise block logins
grantor.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>());
grantor.SetTwoFactorProviders([]);
// Disable New Device Verification since it will otherwise block logins
grantor.VerifyDevices = false;
await _userRepository.ReplaceAsync(grantor);
// Remove grantor from all organizations unless Owner
@ -421,12 +429,22 @@ public class EmergencyAccessService : IEmergencyAccessService
await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token);
}
private string NameOrEmail(User user)
private static string NameOrEmail(User user)
{
return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name;
}
private bool IsValidRequest(EmergencyAccess availableAccess, User requestingUser, EmergencyAccessType requestedAccessType)
/*
* Checks if EmergencyAccess Object is null
* Checks the requesting user is the same as the granteeUser (So we are checking for proper grantee action)
* Status _must_ equal RecoveryApproved (This means the grantor has invited, the grantee has accepted, and the grantor has approved so the shared key exists but hasn't been exercised yet)
* request type must equal the type of access requested (View or Takeover)
*/
private static bool IsValidRequest(
EmergencyAccess availableAccess,
User requestingUser,
EmergencyAccessType requestedAccessType)
{
return availableAccess != null &&
availableAccess.GranteeId == requestingUser.Id &&

View File

@ -2,10 +2,6 @@
public static class StripeConstants
{
public static class Prices
{
public const string StoragePlanPersonal = "personal-storage-gb-annually";
}
public static class AutomaticTaxStatus
{
public const string Failed = "failed";
@ -69,6 +65,11 @@ public static class StripeConstants
public const string USBankAccount = "us_bank_account";
}
public static class Prices
{
public const string StoragePlanPersonal = "personal-storage-gb-annually";
}
public static class ProrationBehavior
{
public const string AlwaysInvoice = "always_invoice";
@ -88,6 +89,13 @@ public static class StripeConstants
public const string Paused = "paused";
}
public static class TaxExempt
{
public const string Exempt = "exempt";
public const string None = "none";
public const string Reverse = "reverse";
}
public static class ValidateTaxLocationTiming
{
public const string Deferred = "deferred";

View File

@ -15,12 +15,7 @@ public static class CustomerExtensions
}
};
/// <summary>
/// Determines if a Stripe customer supports automatic tax
/// </summary>
/// <param name="customer"></param>
/// <returns></returns>
public static bool HasTaxLocationVerified(this Customer customer) =>
public static bool HasRecognizedTaxLocation(this Customer customer) =>
customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation;
public static decimal GetBillingBalance(this Customer customer)

View File

@ -22,7 +22,7 @@ public static class SubscriptionUpdateOptionsExtensions
}
// We might only need to check the automatic tax status.
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country))
if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{
return false;
}

View File

@ -22,7 +22,7 @@ public static class UpcomingInvoiceOptionsExtensions
}
// We might only need to check the automatic tax status.
if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country))
if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country))
{
return false;
}

View File

@ -1,11 +1,11 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums;
@ -35,16 +35,15 @@ public class OrganizationBillingService(
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
ITaxService taxService,
IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService
ITaxService taxService) : IOrganizationBillingService
{
public async Task Finalize(OrganizationSale sale)
{
var (organization, customerSetup, subscriptionSetup) = sale;
var customer = string.IsNullOrEmpty(organization.GatewayCustomerId) && customerSetup != null
? await CreateCustomerAsync(organization, customerSetup)
: await subscriberService.GetCustomerOrThrow(organization, new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType)
: await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup);
var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup);
@ -121,7 +120,8 @@ public class OrganizationBillingService(
subscription.CurrentPeriodEnd);
}
public async Task UpdatePaymentMethod(
public async Task
UpdatePaymentMethod(
Organization organization,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation)
@ -151,8 +151,11 @@ public class OrganizationBillingService(
private async Task<Customer> CreateCustomerAsync(
Organization organization,
CustomerSetup customerSetup)
CustomerSetup customerSetup,
PlanType? updatedPlanType = null)
{
var planType = updatedPlanType ?? organization.PlanType;
var displayName = organization.DisplayName();
var customerCreateOptions = new CustomerCreateOptions
@ -212,13 +215,24 @@ public class OrganizationBillingService(
City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country,
Country = customerSetup.TaxInformation.Country
};
customerCreateOptions.Tax = new CustomerTaxOptions
{
ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately
};
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge &&
planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families &&
customerSetup.TaxInformation.Country != "US")
{
customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse;
}
if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId))
{
var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country,
@ -399,21 +413,68 @@ public class OrganizationBillingService(
TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays
};
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType);
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
}
else
else if (customer.HasRecognizedTaxLocation())
{
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions();
subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation();
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled =
subscriptionSetup.PlanType.GetProductTier() == ProductTierType.Families ||
customer.Address.Country == "US" ||
customer.TaxIds.Any()
};
}
return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
}
private async Task<Customer> GetCustomerWhileEnsuringCorrectTaxExemptionAsync(
Organization organization,
SubscriptionSetup subscriptionSetup)
{
var customer = await subscriberService.GetCustomerOrThrow(organization,
new CustomerGetOptions { Expand = ["tax", "tax_ids"] });
var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (!setNonUSBusinessUseToReverseCharge || subscriptionSetup.PlanType.GetProductTier() is
not (ProductTierType.Teams or
ProductTierType.TeamsStarter or
ProductTierType.Enterprise))
{
return customer;
}
List<string> expansions = ["tax", "tax_ids"];
customer = customer switch
{
{ Address.Country: not "US", TaxExempt: not StripeConstants.TaxExempt.Reverse } => await
stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions
{
Expand = expansions,
TaxExempt = StripeConstants.TaxExempt.Reverse
}),
{ Address.Country: "US", TaxExempt: StripeConstants.TaxExempt.Reverse } => await
stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions
{
Expand = expansions,
TaxExempt = StripeConstants.TaxExempt.None
}),
_ => customer
};
return customer;
}
private async Task<bool> IsEligibleForSelfHostAsync(
Organization organization)
{

View File

@ -3,8 +3,6 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Billing.Tax.Services.Implementations;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
@ -12,7 +10,6 @@ using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Braintree;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Stripe;
using Customer = Stripe.Customer;
@ -24,20 +21,18 @@ using static Utilities;
public class PremiumUserBillingService(
IBraintreeGateway braintreeGateway,
IFeatureService featureService,
IGlobalSettings globalSettings,
ILogger<PremiumUserBillingService> logger,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
IUserRepository userRepository,
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService
IUserRepository userRepository) : IPremiumUserBillingService
{
public async Task Credit(User user, decimal amount)
{
var customer = await subscriberService.GetCustomer(user);
// Negative credit represents a balance and all Stripe denomination is in cents.
// Negative credit represents a balance, and all Stripe denomination is in cents.
var credit = (long)(amount * -100);
if (customer == null)
@ -184,7 +179,7 @@ public class PremiumUserBillingService(
City = customerSetup.TaxInformation.City,
PostalCode = customerSetup.TaxInformation.PostalCode,
State = customerSetup.TaxInformation.State,
Country = customerSetup.TaxInformation.Country,
Country = customerSetup.TaxInformation.Country
},
Description = user.Name,
Email = user.Email,
@ -324,6 +319,10 @@ public class PremiumUserBillingService(
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically,
Customer = customer.Id,
Items = subscriptionItemOptionsList,
@ -337,18 +336,6 @@ public class PremiumUserBillingService(
OffSession = true
};
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported,
};
}
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
if (usingPayPal)
@ -380,7 +367,7 @@ public class PremiumUserBillingService(
City = taxInformation.City,
PostalCode = taxInformation.PostalCode,
State = taxInformation.State,
Country = taxInformation.Country,
Country = taxInformation.Country
},
Expand = ["tax"],
Tax = new CustomerTaxOptions

View File

@ -1,7 +1,10 @@
using Bit.Core.Billing.Caches;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Entities;
@ -28,8 +31,7 @@ public class SubscriberService(
ILogger<SubscriberService> logger,
ISetupIntentCache setupIntentCache,
IStripeAdapter stripeAdapter,
ITaxService taxService,
IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService
ITaxService taxService) : ISubscriberService
{
public async Task CancelSubscription(
ISubscriber subscriber,
@ -128,7 +130,7 @@ public class SubscriberService(
[subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion
},
Email = subscriber.BillingEmailAddress(),
PaymentMethodNonce = paymentMethodNonce,
PaymentMethodNonce = paymentMethodNonce
});
if (customerResult.IsSuccess())
@ -482,7 +484,7 @@ public class SubscriberService(
var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First();
// Find the customer's existing setup intents that should be cancelled.
// Find the customer's existing setup intents that should be canceled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
@ -519,7 +521,7 @@ public class SubscriberService(
await stripeAdapter.PaymentMethodAttachAsync(token,
new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId });
// Find the customer's existing setup intents that should be cancelled.
// Find the customer's existing setup intents that should be canceled.
var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer)
.Where(si =>
si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action");
@ -637,7 +639,8 @@ public class SubscriberService(
logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInformation.Country,
taxInformation.TaxId);
throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError");
throw new BadRequestException("billingTaxIdTypeInferenceError");
}
}
@ -654,53 +657,84 @@ public class SubscriberService(
logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.",
taxInformation.TaxId,
taxInformation.Country);
throw new Exceptions.BadRequestException("billingInvalidTaxIdError");
throw new BadRequestException("billingInvalidTaxIdError");
default:
logger.LogError(e,
"Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.",
taxInformation.TaxId,
taxInformation.Country,
customer.Id);
throw new Exceptions.BadRequestException("billingTaxIdCreationError");
throw new BadRequestException("billingTaxIdCreationError");
}
}
}
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
var subscription =
customer.Subscriptions.First(subscription => subscription.Id == subscriber.GatewaySubscriptionId);
var isBusinessUseSubscriber = subscriber switch
{
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
{
var subscriptionGetOptions = new SubscriptionGetOptions
{
Expand = ["customer.tax", "customer.tax_ids"]
Organization organization => organization.PlanType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families,
Provider => true,
_ => false
};
var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id));
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (automaticTaxOptions?.AutomaticTax?.Enabled != null)
var setNonUSBusinessUseToReverseCharge =
featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge && isBusinessUseSubscriber)
{
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions);
}
}
}
else
switch (customer)
{
if (SubscriberIsEligibleForAutomaticTax(subscriber, customer))
case
{
await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId,
Address.Country: not "US",
TaxExempt: not StripeConstants.TaxExempt.Reverse
}:
await stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
break;
case
{
Address.Country: "US",
TaxExempt: StripeConstants.TaxExempt.Reverse
}:
await stripeAdapter.CustomerUpdateAsync(customer.Id,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.None });
break;
}
if (!subscription.AutomaticTax.Enabled)
{
await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
}
else
{
var automaticTaxShouldBeEnabled = subscriber switch
{
User => true,
Organization organization => organization.PlanType.GetProductTier() == ProductTierType.Families ||
customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
Provider => customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false),
_ => false
};
return;
bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer)
=> !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) &&
(localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) &&
localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported;
if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled)
{
await stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
}
}

View File

@ -76,7 +76,7 @@ public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : I
private bool ShouldBeEnabled(Customer customer)
{
if (!customer.HasTaxLocationVerified())
if (!customer.HasRecognizedTaxLocation())
{
return false;
}

View File

@ -59,6 +59,6 @@ public class PersonalUseAutomaticTaxStrategy(IFeatureService featureService) : I
private static bool ShouldBeEnabled(Customer customer)
{
return customer.HasTaxLocationVerified();
return customer.HasRecognizedTaxLocation();
}
}

View File

@ -143,13 +143,13 @@ public static class FeatureFlagKeys
public const string UsePricingService = "use-pricing-service";
public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features";
public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method";
public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements";
public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates";
public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion";
public const string PM199566_UpdateMSPToChargeAutomatically = "pm-199566-update-msp-to-charge-automatically";
public const string PM19956_RequireProviderPaymentMethodDuringSetup = "pm-19956-require-provider-payment-method-during-setup";
public const string UseOrganizationWarningsService = "use-organization-warnings-service";
public const string PM20322_AllowTrialLength0 = "pm-20322-allow-trial-length-0";
public const string PM21092_SetNonUSBusinessUseToReverseCharge = "pm-21092-set-non-us-business-use-to-reverse-charge";
/* Data Insights and Reporting Team */
public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application";

View File

@ -255,7 +255,10 @@ public class OrganizationLicense : ILicense
!p.Name.Equals(nameof(Refresh))
)
) &&
!p.Name.Equals(nameof(UseRiskInsights)))
// any new fields added need to be added here so that they're ignored
!p.Name.Equals(nameof(UseRiskInsights)) &&
!p.Name.Equals(nameof(UseAdminSponsoredFamilies)) &&
!p.Name.Equals(nameof(UseOrganizationDomains)))
.OrderBy(p => p.Name)
.Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}")
.Aggregate((c, n) => $"{c}|{n}");

View File

@ -4,7 +4,6 @@ using Bit.Core.Billing.Models;
using Bit.Core.Billing.Tax.Requests;
using Bit.Core.Billing.Tax.Responses;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Business;
using Bit.Core.Models.StaticStore;
@ -30,8 +29,6 @@ public interface IPaymentService
Task<string> AdjustServiceAccountsAsync(Organization organization, Plan plan, int additionalServiceAccounts);
Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false);
Task ReinstateSubscriptionAsync(ISubscriber subscriber);
Task<bool> UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType,
string paymentToken, TaxInfo taxInfo = null);
Task<bool> CreditAccountAsync(ISubscriber subscriber, decimal creditAmount);
Task<BillingInfo> GetBillingAsync(ISubscriber subscriber);
Task<BillingHistoryInfo> GetBillingHistoryAsync(ISubscriber subscriber);

View File

@ -1,13 +1,13 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Requests;
using Bit.Core.Billing.Tax.Responses;
using Bit.Core.Billing.Tax.Services;
@ -38,7 +38,6 @@ public class StripePaymentService : IPaymentService
private readonly IGlobalSettings _globalSettings;
private readonly IFeatureService _featureService;
private readonly ITaxService _taxService;
private readonly ISubscriberService _subscriberService;
private readonly IPricingClient _pricingClient;
private readonly IAutomaticTaxFactory _automaticTaxFactory;
private readonly IAutomaticTaxStrategy _personalUseTaxStrategy;
@ -51,7 +50,6 @@ public class StripePaymentService : IPaymentService
IGlobalSettings globalSettings,
IFeatureService featureService,
ITaxService taxService,
ISubscriberService subscriberService,
IPricingClient pricingClient,
IAutomaticTaxFactory automaticTaxFactory,
[FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy)
@ -63,7 +61,6 @@ public class StripePaymentService : IPaymentService
_globalSettings = globalSettings;
_featureService = featureService;
_taxService = taxService;
_subscriberService = subscriberService;
_pricingClient = pricingClient;
_automaticTaxFactory = automaticTaxFactory;
_personalUseTaxStrategy = personalUseTaxStrategy;
@ -136,15 +133,68 @@ public class StripePaymentService : IPaymentService
if (subscriptionUpdate is CompleteSubscriptionUpdate)
{
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
var setNonUSBusinessUseToReverseCharge =
_featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge);
if (setNonUSBusinessUseToReverseCharge)
{
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, updatedItemOptions.Select(x => x.Plan ?? x.Price));
var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters);
automaticTaxStrategy.SetUpdateOptions(subUpdateOptions, sub);
if (sub.Customer is
{
Address.Country: not "US",
TaxExempt: not StripeConstants.TaxExempt.Reverse
})
{
await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId,
new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse });
}
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
}
else if (sub.Customer.HasRecognizedTaxLocation())
{
switch (subscriber)
{
case User:
{
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
break;
}
case Organization:
{
if (sub.Customer.Address.Country == "US")
{
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
}
else
{
subUpdateOptions.EnableAutomaticTax(sub.Customer, sub);
var familyPriceIds = (await Task.WhenAll(
_pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019),
_pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually)))
.Select(plan => plan.PasswordManager.StripePlanId);
var updateIsForPersonalUse = updatedItemOptions
.Select(option => option.Price)
.Intersect(familyPriceIds)
.Any();
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = updateIsForPersonalUse || sub.Customer.TaxIds.Any()
};
}
break;
}
case Provider:
{
subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = sub.Customer.Address.Country == "US" ||
sub.Customer.TaxIds.Any()
};
break;
}
}
}
}
@ -202,7 +252,7 @@ public class StripePaymentService : IPaymentService
}
else if (!invoice.Paid)
{
// Pay invoice with no charge to customer this completes the invoice immediately without waiting the scheduled 1h
// Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h
invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId);
paymentIntentClientSecret = null;
}
@ -585,309 +635,6 @@ public class StripePaymentService : IPaymentService
}
}
public async Task<bool> UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType,
string paymentToken, TaxInfo taxInfo = null)
{
if (subscriber == null)
{
throw new ArgumentNullException(nameof(subscriber));
}
if (subscriber.Gateway.HasValue && subscriber.Gateway.Value != GatewayType.Stripe)
{
throw new GatewayException("Switching from one payment type to another is not supported. " +
"Contact us for assistance.");
}
var createdCustomer = false;
Braintree.Customer braintreeCustomer = null;
string stipeCustomerSourceToken = null;
string stipeCustomerPaymentMethodId = null;
var stripeCustomerMetadata = new Dictionary<string, string>
{
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
};
var stripePaymentMethod = paymentMethodType is PaymentMethodType.Card or PaymentMethodType.BankAccount;
Customer customer = null;
if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId))
{
var options = new CustomerGetOptions { Expand = ["sources", "tax", "subscriptions"] };
customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, options);
if (customer.Metadata?.Any() ?? false)
{
stripeCustomerMetadata = customer.Metadata;
}
}
var hadBtCustomer = stripeCustomerMetadata.ContainsKey("btCustomerId");
if (stripePaymentMethod)
{
if (paymentToken.StartsWith("pm_"))
{
stipeCustomerPaymentMethodId = paymentToken;
}
else
{
stipeCustomerSourceToken = paymentToken;
}
}
else if (paymentMethodType == PaymentMethodType.PayPal)
{
if (hadBtCustomer)
{
var pmResult = await _btGateway.PaymentMethod.CreateAsync(new Braintree.PaymentMethodRequest
{
CustomerId = stripeCustomerMetadata["btCustomerId"],
PaymentMethodNonce = paymentToken
});
if (pmResult.IsSuccess())
{
var customerResult = await _btGateway.Customer.UpdateAsync(
stripeCustomerMetadata["btCustomerId"], new Braintree.CustomerRequest
{
DefaultPaymentMethodToken = pmResult.Target.Token
});
if (customerResult.IsSuccess() && customerResult.Target.PaymentMethods.Length > 0)
{
braintreeCustomer = customerResult.Target;
}
else
{
await _btGateway.PaymentMethod.DeleteAsync(pmResult.Target.Token);
hadBtCustomer = false;
}
}
else
{
hadBtCustomer = false;
}
}
if (!hadBtCustomer)
{
var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest
{
PaymentMethodNonce = paymentToken,
Email = subscriber.BillingEmailAddress(),
Id = subscriber.BraintreeCustomerIdPrefix() + subscriber.Id.ToString("N").ToLower() +
Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false),
CustomFields = new Dictionary<string, string>
{
[subscriber.BraintreeIdField()] = subscriber.Id.ToString(),
[subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion
}
});
if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0)
{
throw new GatewayException("Failed to create PayPal customer record.");
}
braintreeCustomer = customerResult.Target;
}
}
else
{
throw new GatewayException("Payment method is not supported at this time.");
}
if (stripeCustomerMetadata.ContainsKey("btCustomerId"))
{
if (braintreeCustomer?.Id != stripeCustomerMetadata["btCustomerId"])
{
stripeCustomerMetadata["btCustomerId_old"] = stripeCustomerMetadata["btCustomerId"];
}
stripeCustomerMetadata["btCustomerId"] = braintreeCustomer?.Id;
}
else if (!string.IsNullOrWhiteSpace(braintreeCustomer?.Id))
{
stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id);
}
try
{
if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber))
{
taxInfo.TaxIdType = taxInfo.TaxIdType ??
_taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber);
}
if (customer == null)
{
customer = await _stripeAdapter.CustomerCreateAsync(new CustomerCreateOptions
{
Description = subscriber.BillingName(),
Email = subscriber.BillingEmailAddress(),
Metadata = stripeCustomerMetadata,
Source = stipeCustomerSourceToken,
PaymentMethod = stipeCustomerPaymentMethodId,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
DefaultPaymentMethod = stipeCustomerPaymentMethodId,
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions()
{
Name = subscriber.SubscriberType(),
Value = subscriber.GetFormattedInvoiceName()
}
]
},
Address = taxInfo == null ? null : new AddressOptions
{
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode,
Line1 = taxInfo.BillingAddressLine1 ?? string.Empty,
Line2 = taxInfo.BillingAddressLine2,
City = taxInfo.BillingAddressCity,
State = taxInfo.BillingAddressState
},
TaxIdData = string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)
? []
: [
new CustomerTaxIdDataOptions
{
Type = taxInfo.TaxIdType,
Value = taxInfo.TaxIdNumber
}
],
Expand = ["sources", "tax", "subscriptions"],
});
subscriber.Gateway = GatewayType.Stripe;
subscriber.GatewayCustomerId = customer.Id;
createdCustomer = true;
}
if (!createdCustomer)
{
string defaultSourceId = null;
string defaultPaymentMethodId = null;
if (stripePaymentMethod)
{
if (!string.IsNullOrWhiteSpace(stipeCustomerSourceToken) && paymentToken.StartsWith("btok_"))
{
var bankAccount = await _stripeAdapter.BankAccountCreateAsync(customer.Id, new BankAccountCreateOptions
{
Source = paymentToken
});
defaultSourceId = bankAccount.Id;
}
else if (!string.IsNullOrWhiteSpace(stipeCustomerPaymentMethodId))
{
await _stripeAdapter.PaymentMethodAttachAsync(stipeCustomerPaymentMethodId,
new PaymentMethodAttachOptions { Customer = customer.Id });
defaultPaymentMethodId = stipeCustomerPaymentMethodId;
}
}
if (customer.Sources != null)
{
foreach (var source in customer.Sources.Where(s => s.Id != defaultSourceId))
{
if (source is BankAccount)
{
await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id);
}
else if (source is Card)
{
await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id);
}
}
}
var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging(new PaymentMethodListOptions
{
Customer = customer.Id,
Type = "card"
});
foreach (var cardMethod in cardPaymentMethods.Where(m => m.Id != defaultPaymentMethodId))
{
await _stripeAdapter.PaymentMethodDetachAsync(cardMethod.Id, new PaymentMethodDetachOptions());
}
await _subscriberService.UpdateTaxInformation(subscriber, TaxInformation.From(taxInfo));
customer = await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Metadata = stripeCustomerMetadata,
DefaultSource = defaultSourceId,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
DefaultPaymentMethod = defaultPaymentMethodId,
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions()
{
Name = subscriber.SubscriberType(),
Value = subscriber.GetFormattedInvoiceName()
}
]
},
Expand = ["tax", "subscriptions"]
});
}
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
{
var subscriptionGetOptions = new SubscriptionGetOptions
{
Expand = ["customer.tax", "customer.tax_ids"]
};
var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id));
var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters);
var subscriptionUpdateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (subscriptionUpdateOptions != null)
{
_ = await _stripeAdapter.SubscriptionUpdateAsync(
subscriber.GatewaySubscriptionId,
subscriptionUpdateOptions);
}
}
}
else
{
if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) &&
customer.Subscriptions.Any(sub =>
sub.Id == subscriber.GatewaySubscriptionId &&
!sub.AutomaticTax.Enabled) &&
customer.HasTaxLocationVerified())
{
var subscriptionUpdateOptions = new SubscriptionUpdateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
DefaultTaxRates = []
};
_ = await _stripeAdapter.SubscriptionUpdateAsync(
subscriber.GatewaySubscriptionId,
subscriptionUpdateOptions);
}
}
}
catch
{
if (braintreeCustomer != null && !hadBtCustomer)
{
await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id);
}
throw;
}
return createdCustomer;
}
public async Task<bool> CreditAccountAsync(ISubscriber subscriber, decimal creditAmount)
{
Customer customer = null;
@ -1018,7 +765,7 @@ public class StripePaymentService : IPaymentService
var address = customer.Address;
var taxId = customer.TaxIds?.FirstOrDefault();
// Line1 is required, so if missing we're using the subscriber name
// Line1 is required, so if missing we're using the subscriber name,
// see: https://stripe.com/docs/api/customers/create#create_customer-address-line1
if (address != null && string.IsNullOrWhiteSpace(address.Line1))
{

View File

@ -0,0 +1,19 @@
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.Models.Data;
/// <summary>
/// This enum represents the possible results when attempting to access a <see cref="Send"/>.
/// </summary>
/// <member>name="Granted">Access is granted for the <see cref="Send"/>.</member>
/// <member>name="PasswordRequired">Access is denied, but a password is required to access the <see cref="Send"/>.
/// </member>
/// <member>name="PasswordInvalid">Access is denied due to an invalid password.</member>
/// <member>name="Denied">Access is denied for the <see cref="Send"/>.</member>
public enum SendAccessResult
{
Granted,
PasswordRequired,
PasswordInvalid,
Denied
}

View File

@ -0,0 +1,52 @@
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services;
namespace Bit.Core.Tools.SendFeatures.Commands;
public class AnonymousSendCommand : IAnonymousSendCommand
{
private readonly ISendRepository _sendRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPushNotificationService _pushNotificationService;
private readonly ISendAuthorizationService _sendAuthorizationService;
public AnonymousSendCommand(
ISendRepository sendRepository,
ISendFileStorageService sendFileStorageService,
IPushNotificationService pushNotificationService,
ISendAuthorizationService sendAuthorizationService
)
{
_sendRepository = sendRepository;
_sendFileStorageService = sendFileStorageService;
_pushNotificationService = pushNotificationService;
_sendAuthorizationService = sendAuthorizationService;
}
// Response: Send, password required, password invalid
public async Task<(string, SendAccessResult)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Can only get a download URL for a file type of Send");
}
var result = _sendAuthorizationService.SendCanBeAccessed(send, password);
if (!result.Equals(SendAccessResult.Granted))
{
return (null, result);
}
send.AccessCount++;
await _sendRepository.ReplaceAsync(send);
await _pushNotificationService.PushSyncSendUpdateAsync(send);
return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), result);
}
}

View File

@ -0,0 +1,21 @@
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
namespace Bit.Core.Tools.SendFeatures.Commands.Interfaces;
/// <summary>
/// AnonymousSendCommand interface provides methods for managing anonymous Sends.
/// </summary>
public interface IAnonymousSendCommand
{
/// <summary>
/// Gets the Send file download URL for a Send object.
/// </summary>
/// <param name="send"><see cref="Send" /> used to help get file download url and validate file</param>
/// <param name="fileId">FileId get file download url</param>
/// <param name="password">A hashed and base64-encoded password. This is compared with the send's password to authorize access.</param>
/// <returns>Async Task object with Tuple containing the string of download url and <see cref="SendAccessResult" />
/// to determine if the user can access send.
/// </returns>
Task<(string, SendAccessResult)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password);
}

View File

@ -0,0 +1,53 @@
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
namespace Bit.Core.Tools.SendFeatures.Commands.Interfaces;
/// <summary>
/// NonAnonymousSendCommand interface provides methods for managing non-anonymous Sends.
/// </summary>
public interface INonAnonymousSendCommand
{
/// <summary>
/// Saves a <see cref="Send" /> to the database.
/// </summary>
/// <param name="send"><see cref="Send" /> that will save to database</param>
/// <returns>Task completes as <see cref="Send" /> saves to the database</returns>
Task SaveSendAsync(Send send);
/// <summary>
/// Saves the <see cref="Send" /> and <see cref="SendFileData" /> to the database.
/// </summary>
/// <param name="send"><see cref="Send" /> that will save to the database</param>
/// <param name="data"><see cref="SendFileData" /> that will save to file storage</param>
/// <param name="fileLength">Length of file help with saving to file storage</param>
/// <returns>Task object for async operations with file upload url</returns>
Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength);
/// <summary>
/// Upload a file to an existing <see cref="Send" />.
/// </summary>
/// <param name="stream"><see cref="Stream" /> of file to be uploaded. The <see cref="Stream" /> position
/// will be set to 0 before uploading the file.</param>
/// <param name="send"><see cref="Send" /> used to help with uploading file</param>
/// <returns>Task completes after saving <see cref="Stream" /> and <see cref="Send" /> metadata to the file storage</returns>
Task UploadFileToExistingSendAsync(Stream stream, Send send);
/// <summary>
/// Deletes a <see cref="Send" /> from the database and file storage.
/// </summary>
/// <param name="send"><see cref="Send" /> is used to delete from database and file storage</param>
/// <returns>Task completes once <see cref="Send" /> has been deleted from database and file storage.</returns>
Task DeleteSendAsync(Send send);
/// <summary>
/// Stores the confirmed file size of a send; when the file size cannot be confirmed, the send is deleted.
/// </summary>
/// <param name="send">The <see cref="Send" /> this command acts upon</param>
/// <returns><see langword="true" /> when the file is confirmed, otherwise <see langword="false" /></returns>
/// <remarks>
/// When a file size cannot be confirmed, we assume we're working with a rogue client. The send is deleted out of
/// an abundance of caution.
/// </remarks>
Task<bool> ConfirmFileSize(Send send);
}

View File

@ -0,0 +1,180 @@
using System.Text.Json;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
namespace Bit.Core.Tools.SendFeatures.Commands;
public class NonAnonymousSendCommand : INonAnonymousSendCommand
{
private readonly ISendRepository _sendRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPushNotificationService _pushNotificationService;
private readonly ISendValidationService _sendValidationService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
private readonly ISendCoreHelperService _sendCoreHelperService;
public NonAnonymousSendCommand(ISendRepository sendRepository,
ISendFileStorageService sendFileStorageService,
IPushNotificationService pushNotificationService,
ISendAuthorizationService sendAuthorizationService,
ISendValidationService sendValidationService,
IReferenceEventService referenceEventService,
ICurrentContext currentContext,
ISendCoreHelperService sendCoreHelperService)
{
_sendRepository = sendRepository;
_sendFileStorageService = sendFileStorageService;
_pushNotificationService = pushNotificationService;
_sendValidationService = sendValidationService;
_referenceEventService = referenceEventService;
_currentContext = currentContext;
_sendCoreHelperService = sendCoreHelperService;
}
public async Task SaveSendAsync(Send send)
{
// Make sure user can save Sends
await _sendValidationService.ValidateUserCanSaveAsync(send.UserId, send);
if (send.Id == default(Guid))
{
await _sendRepository.CreateAsync(send);
await _pushNotificationService.PushSyncSendCreateAsync(send);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
{
Id = send.UserId ?? default,
Type = ReferenceEventType.SendCreated,
Source = ReferenceEventSource.User,
SendType = send.Type,
MaxAccessCount = send.MaxAccessCount,
HasPassword = !string.IsNullOrWhiteSpace(send.Password),
SendHasNotes = send.Data?.Contains("Notes"),
ClientId = _currentContext.ClientId,
ClientVersion = _currentContext.ClientVersion
});
}
else
{
send.RevisionDate = DateTime.UtcNow;
await _sendRepository.UpsertAsync(send);
await _pushNotificationService.PushSyncSendUpdateAsync(send);
}
}
public async Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Send is not of type \"file\".");
}
if (fileLength < 1)
{
throw new BadRequestException("No file data.");
}
var storageBytesRemaining = await _sendValidationService.StorageRemainingForSendAsync(send);
if (storageBytesRemaining < fileLength)
{
throw new BadRequestException("Not enough storage available.");
}
var fileId = _sendCoreHelperService.SecureRandomString(32, useUpperCase: false, useSpecial: false);
try
{
data.Id = fileId;
data.Size = fileLength;
data.Validated = false;
send.Data = JsonSerializer.Serialize(data,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId);
}
catch
{
// Clean up since this is not transactional
await _sendFileStorageService.DeleteFileAsync(send, fileId);
throw;
}
}
public async Task UploadFileToExistingSendAsync(Stream stream, Send send)
{
if (stream.Position > 0)
{
stream.Position = 0;
}
if (send?.Data == null)
{
throw new BadRequestException("Send does not have file data");
}
if (send.Type != SendType.File)
{
throw new BadRequestException("Not a File Type Send.");
}
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
if (data.Validated)
{
throw new BadRequestException("File has already been uploaded.");
}
await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id);
if (!await ConfirmFileSize(send))
{
throw new BadRequestException("File received does not match expected file length.");
}
}
public async Task DeleteSendAsync(Send send)
{
await _sendRepository.DeleteAsync(send);
if (send.Type == Enums.SendType.File)
{
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
await _sendFileStorageService.DeleteFileAsync(send, data.Id);
}
await _pushNotificationService.PushSyncSendDeleteAsync(send);
}
public async Task<bool> ConfirmFileSize(Send send)
{
var fileData = JsonSerializer.Deserialize<SendFileData>(send.Data);
var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, SendFileSettingHelper.FILE_SIZE_LEEWAY);
if (!valid || realSize > SendFileSettingHelper.FILE_SIZE_LEEWAY)
{
// File reported differs in size from that promised. Must be a rogue client. Delete Send
await DeleteSendAsync(send);
return false;
}
// Update Send data if necessary
if (realSize != fileData.Size)
{
fileData.Size = realSize.Value;
}
fileData.Validated = true;
send.Data = JsonSerializer.Serialize(fileData,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return valid;
}
}

View File

@ -0,0 +1,18 @@
using Bit.Core.Tools.SendFeatures.Commands;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services;
using Microsoft.Extensions.DependencyInjection;
namespace Bit.Core.Tools.SendFeatures;
public static class SendServiceCollectionExtension
{
public static void AddSendServices(this IServiceCollection services)
{
services.AddScoped<INonAnonymousSendCommand, NonAnonymousSendCommand>();
services.AddScoped<IAnonymousSendCommand, AnonymousSendCommand>();
services.AddScoped<ISendAuthorizationService, SendAuthorizationService>();
services.AddScoped<ISendValidationService, SendValidationService>();
services.AddScoped<ISendCoreHelperService, SendCoreHelperService>();
}
}

View File

@ -0,0 +1,28 @@
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
namespace Bit.Core.Tools.Services;
/// <summary>
/// Send Authorization service is responsible for checking if a Send can be accessed.
/// </summary>
public interface ISendAuthorizationService
{
/// <summary>
/// Checks if a <see cref="Send" /> can be accessed while updating the <see cref="Send" />, pushing a notification, and sending a reference event.
/// </summary>
/// <param name="send"><see cref="Send" /> used to determine access</param>
/// <param name="password">A hashed and base64-encoded password. This is compared with the send's password to authorize access.</param>
/// <returns><see cref="SendAccessResult" /> will be returned to determine if the user can access send.
/// </returns>
Task<SendAccessResult> AccessAsync(Send send, string password);
SendAccessResult SendCanBeAccessed(Send send,
string password);
/// <summary>
/// Hashes the password using the password hasher.
/// </summary>
/// <param name="password">Password to be hashed</param>
/// <returns>Hashed password of the password given</returns>
string HashPassword(string password);
}

View File

@ -0,0 +1,17 @@
namespace Bit.Core.Tools.Services;
/// <summary>
/// This interface provides helper methods for generating secure random strings. Making
/// it easier to mock the service in unit tests.
/// </summary>
public interface ISendCoreHelperService
{
/// <summary>
/// Securely generates a random string of the specified length.
/// </summary>
/// <param name="length">Desired string length to be returned</param>
/// <param name="useUpperCase">Desired casing for the string</param>
/// <param name="useSpecial">Determines if special characters will be used in string</param>
/// <returns>A secure random string with the desired parameters</returns>
string SecureRandomString(int length, bool useUpperCase, bool useSpecial);
}

View File

@ -0,0 +1,71 @@
using Bit.Core.Enums;
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.Services;
/// <summary>
/// Send File Storage Service is responsible for uploading, deleting, and validating files
/// whether they are in local storage or in cloud storage.
/// </summary>
public interface ISendFileStorageService
{
FileUploadType FileUploadType { get; }
/// <summary>
/// Uploads a new file to the storage.
/// </summary>
/// <param name="stream"><see cref="Stream" /> of the file</param>
/// <param name="send"><see cref="Send" /> for the file</param>
/// <param name="fileId">File id</param>
/// <returns>Task completes once <see cref="Stream" /> and <see cref="Send" /> have been saved to the database</returns>
Task UploadNewFileAsync(Stream stream, Send send, string fileId);
/// <summary>
/// Deletes a file from the storage.
/// </summary>
/// <param name="send"><see cref="Send" /> used to delete file</param>
/// <param name="fileId">File id of file to be deleted</param>
/// <returns>Task completes once <see cref="Send" /> has been deleted to the database</returns>
Task DeleteFileAsync(Send send, string fileId);
/// <summary>
/// Deletes all files for a specific organization.
/// </summary>
/// <param name="organizationId"><see cref="Guid" /> used to delete all files pertaining to organization</param>
/// <returns>Task completes after running code to delete files by organization id</returns>
Task DeleteFilesForOrganizationAsync(Guid organizationId);
/// <summary>
/// Deletes all files for a specific user.
/// </summary>
/// <param name="userId"><see cref="Guid" /> used to delete all files pertaining to user</param>
/// <returns>Task completes after running code to delete files by user id</returns>
Task DeleteFilesForUserAsync(Guid userId);
/// <summary>
/// Gets the download URL for a file.
/// </summary>
/// <param name="send"><see cref="Send" /> used to help get download url for file</param>
/// <param name="fileId">File id to help get download url for file</param>
/// <returns>Download url as a string</returns>
Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId);
/// <summary>
/// Gets the upload URL for a file.
/// </summary>
/// <param name="send"><see cref="Send" /> used to help get upload url for file </param>
/// <param name="fileId">File id to help get upload url for file</param>
/// <returns>File upload url as string</returns>
Task<string> GetSendFileUploadUrlAsync(Send send, string fileId);
/// <summary>
/// Validates the file size of a file in the storage.
/// </summary>
/// <param name="send"><see cref="Send" /> used to help validate file</param>
/// <param name="fileId">File id to identify which file to validate</param>
/// <param name="expectedFileSize">Expected file size of the file</param>
/// <param name="leeway">
/// Send file size tolerance in bytes. When an uploaded file's `expectedFileSize`
/// is outside of the leeway, the storage operation fails.
/// </param>
/// <throws>
/// ❌ Fill this in with an explanation of the error thrown when `leeway` is incorrect
/// </throws>
/// <returns>Task object for async operations with Tuple of boolean that determines if file was valid and long that
/// the actual file size of the file.
/// </returns>
Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway);
}

View File

@ -0,0 +1,35 @@
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.Services;
public interface ISendValidationService
{
/// <summary>
/// Validates a file can be saved by specified user.
/// </summary>
/// <param name="userId"><see cref="Guid" /> needed to validate file for specific user</param>
/// <param name="send"><see cref="Send" /> needed to help validate file</param>
/// <returns>Task completes when a conditional statement has been met it will return out of the method or
/// throw a BadRequestException.
/// </returns>
Task ValidateUserCanSaveAsync(Guid? userId, Send send);
/// <summary>
/// Validates a file can be saved by specified user with different policy based on feature flag
/// </summary>
/// <param name="userId"><see cref="Guid" /> needed to validate file for specific user</param>
/// <param name="send"><see cref="Send" /> needed to help validate file</param>
/// <returns>Task completes when a conditional statement has been met it will return out of the method or
/// throw a BadRequestException.
/// </returns>
Task ValidateUserCanSaveAsync_vNext(Guid? userId, Send send);
/// <summary>
/// Calculates the remaining storage for a Send.
/// </summary>
/// <param name="send"><see cref="Send" /> needed to help calculate remaining storage</param>
/// <returns>Long with the remaining bytes for storage or will throw a BadRequestException if user cannot access
/// file or email is not verified.
/// </returns>
Task<long> StorageRemainingForSendAsync(Send send);
}

View File

@ -0,0 +1,101 @@
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Tools.Services;
public class SendAuthorizationService : ISendAuthorizationService
{
private readonly ISendRepository _sendRepository;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IPushNotificationService _pushNotificationService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
public SendAuthorizationService(
ISendRepository sendRepository,
IPasswordHasher<User> passwordHasher,
IPushNotificationService pushNotificationService,
IReferenceEventService referenceEventService,
ICurrentContext currentContext)
{
_sendRepository = sendRepository;
_passwordHasher = passwordHasher;
_pushNotificationService = pushNotificationService;
_referenceEventService = referenceEventService;
_currentContext = currentContext;
}
public SendAccessResult SendCanBeAccessed(Send send,
string password)
{
var now = DateTime.UtcNow;
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
send.DeletionDate < now)
{
return SendAccessResult.Denied;
}
if (!string.IsNullOrWhiteSpace(send.Password))
{
if (string.IsNullOrWhiteSpace(password))
{
return SendAccessResult.PasswordRequired;
}
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
{
send.Password = HashPassword(password);
}
if (passwordResult == PasswordVerificationResult.Failed)
{
return SendAccessResult.PasswordInvalid;
}
}
return SendAccessResult.Granted;
}
public async Task<SendAccessResult> AccessAsync(Send sendToBeAccessed, string password)
{
var accessResult = SendCanBeAccessed(sendToBeAccessed, password);
if (!accessResult.Equals(SendAccessResult.Granted))
{
return accessResult;
}
if (sendToBeAccessed.Type != SendType.File)
{
// File sends are incremented during file download
sendToBeAccessed.AccessCount++;
}
await _sendRepository.ReplaceAsync(sendToBeAccessed);
await _pushNotificationService.PushSyncSendUpdateAsync(sendToBeAccessed);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
{
Id = sendToBeAccessed.UserId ?? default,
Type = ReferenceEventType.SendAccessed,
Source = ReferenceEventSource.User,
SendType = sendToBeAccessed.Type,
MaxAccessCount = sendToBeAccessed.MaxAccessCount,
HasPassword = !string.IsNullOrWhiteSpace(sendToBeAccessed.Password),
SendHasNotes = sendToBeAccessed.Data?.Contains("Notes"),
ClientId = _currentContext.ClientId,
ClientVersion = _currentContext.ClientVersion
});
return accessResult;
}
public string HashPassword(string password)
{
return _passwordHasher.HashPassword(new User(), password);
}
}

View File

@ -0,0 +1,12 @@
using Bit.Core.Utilities;
namespace Bit.Core.Tools.Services;
public class SendCoreHelperService : ISendCoreHelperService
{
public string SecureRandomString(int length, bool useUpperCase, bool useSpecial)
{
return CoreHelpers.SecureRandomString(length, upper: useUpperCase, special: useSpecial);
}
}

View File

@ -0,0 +1,26 @@
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.SendFeatures;
/// <summary>
/// SendFileSettingHelper is a static class that provides constants and helper methods (if needed) for managing file
/// settings.
/// </summary>
public static class SendFileSettingHelper
{
/// <summary>
/// The leeway for the file size. This is the calculated 1 megabyte of cushion when doing comparisons of file sizes
/// within the system.
/// </summary>
public const long FILE_SIZE_LEEWAY = 1024L * 1024L; // 1MB
/// <summary>
/// The maximum file size for a file uploaded in a <see cref="Send" />. Units are calculated in bytes but
/// represent 501 megabytes. 1 megabyte is added for cushion to account for file size.
/// </summary>
public const long MAX_FILE_SIZE = Constants.FileSize501mb;
/// <summary>
/// String of the expected file size and to be used when needing to communicate the file size to the client/user.
/// </summary>
public const string MAX_FILE_SIZE_READABLE = "500 MB";
}

View File

@ -0,0 +1,142 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities;
using Bit.Core.Utilities;
namespace Bit.Core.Tools.Services;
public class SendValidationService : ISendValidationService
{
private readonly IUserRepository _userRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IPolicyService _policyService;
private readonly IFeatureService _featureService;
private readonly IUserService _userService;
private readonly GlobalSettings _globalSettings;
private readonly ICurrentContext _currentContext;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
public SendValidationService(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IPolicyService policyService,
IFeatureService featureService,
IUserService userService,
IPolicyRequirementQuery policyRequirementQuery,
GlobalSettings globalSettings,
ICurrentContext currentContext)
{
_userRepository = userRepository;
_organizationRepository = organizationRepository;
_policyService = policyService;
_featureService = featureService;
_userService = userService;
_policyRequirementQuery = policyRequirementQuery;
_globalSettings = globalSettings;
_currentContext = currentContext;
}
public async Task ValidateUserCanSaveAsync(Guid? userId, Send send)
{
if (_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements))
{
await ValidateUserCanSaveAsync_vNext(userId, send);
return;
}
if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true))
{
return;
}
var anyDisableSendPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(userId.Value,
PolicyType.DisableSend);
if (anyDisableSendPolicies)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
if (send.HideEmail.GetValueOrDefault())
{
var sendOptionsPolicies = await _policyService.GetPoliciesApplicableToUserAsync(userId.Value, PolicyType.SendOptions);
if (sendOptionsPolicies.Any(p => CoreHelpers.LoadClassFromJsonData<SendOptionsPolicyData>(p.PolicyData)?.DisableHideEmail ?? false))
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
}
public async Task ValidateUserCanSaveAsync_vNext(Guid? userId, Send send)
{
if (!userId.HasValue)
{
return;
}
var disableSendRequirement = await _policyRequirementQuery.GetAsync<DisableSendPolicyRequirement>(userId.Value);
if (disableSendRequirement.DisableSend)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
var sendOptionsRequirement = await _policyRequirementQuery.GetAsync<SendOptionsPolicyRequirement>(userId.Value);
if (sendOptionsRequirement.DisableHideEmail && send.HideEmail.GetValueOrDefault())
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
public async Task<long> StorageRemainingForSendAsync(Send send)
{
var storageBytesRemaining = 0L;
if (send.UserId.HasValue)
{
var user = await _userRepository.GetByIdAsync(send.UserId.Value);
if (!await _userService.CanAccessPremium(user))
{
throw new BadRequestException("You must have premium status to use file Sends.");
}
if (!user.EmailVerified)
{
throw new BadRequestException("You must confirm your email to use file Sends.");
}
if (user.Premium)
{
storageBytesRemaining = user.StorageBytesRemaining();
}
else
{
// Users that get access to file storage/premium from their organization get the default
// 1 GB max storage.
short limit = _globalSettings.SelfHosted ? (short)10240 : (short)1;
storageBytesRemaining = user.StorageBytesRemaining(limit);
}
}
else if (send.OrganizationId.HasValue)
{
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
if (!org.MaxStorageGb.HasValue)
{
throw new BadRequestException("This organization cannot use file sends.");
}
storageBytesRemaining = org.StorageBytesRemaining();
}
return storageBytesRemaining;
}
}

View File

@ -1,16 +0,0 @@
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
namespace Bit.Core.Tools.Services;
public interface ISendService
{
Task DeleteSendAsync(Send send);
Task SaveSendAsync(Send send);
Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength);
Task UploadFileToExistingSendAsync(Stream stream, Send send);
Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password);
string HashPassword(string password);
Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password);
Task<bool> ValidateSendFile(Send send);
}

View File

@ -1,16 +0,0 @@
using Bit.Core.Enums;
using Bit.Core.Tools.Entities;
namespace Bit.Core.Tools.Services;
public interface ISendFileStorageService
{
FileUploadType FileUploadType { get; }
Task UploadNewFileAsync(Stream stream, Send send, string fileId);
Task DeleteFileAsync(Send send, string fileId);
Task DeleteFilesForOrganizationAsync(Guid organizationId);
Task DeleteFilesForUserAsync(Guid userId);
Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId);
Task<string> GetSendFileUploadUrlAsync(Send send, string fileId);
Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway);
}

View File

@ -1,383 +0,0 @@
using System.Text.Json;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Tools.Services;
public class SendService : ISendService
{
public const long MAX_FILE_SIZE = Constants.FileSize501mb;
public const string MAX_FILE_SIZE_READABLE = "500 MB";
private readonly ISendRepository _sendRepository;
private readonly IUserRepository _userRepository;
private readonly IPolicyService _policyService;
private readonly IUserService _userService;
private readonly IOrganizationRepository _organizationRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IPushNotificationService _pushService;
private readonly IReferenceEventService _referenceEventService;
private readonly GlobalSettings _globalSettings;
private readonly ICurrentContext _currentContext;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IFeatureService _featureService;
private const long _fileSizeLeeway = 1024L * 1024L; // 1MB
public SendService(
ISendRepository sendRepository,
IUserRepository userRepository,
IUserService userService,
IOrganizationRepository organizationRepository,
ISendFileStorageService sendFileStorageService,
IPasswordHasher<User> passwordHasher,
IPushNotificationService pushService,
IReferenceEventService referenceEventService,
GlobalSettings globalSettings,
IPolicyService policyService,
ICurrentContext currentContext,
IPolicyRequirementQuery policyRequirementQuery,
IFeatureService featureService)
{
_sendRepository = sendRepository;
_userRepository = userRepository;
_userService = userService;
_policyService = policyService;
_organizationRepository = organizationRepository;
_sendFileStorageService = sendFileStorageService;
_passwordHasher = passwordHasher;
_pushService = pushService;
_referenceEventService = referenceEventService;
_globalSettings = globalSettings;
_currentContext = currentContext;
_policyRequirementQuery = policyRequirementQuery;
_featureService = featureService;
}
public async Task SaveSendAsync(Send send)
{
// Make sure user can save Sends
await ValidateUserCanSaveAsync(send.UserId, send);
if (send.Id == default(Guid))
{
await _sendRepository.CreateAsync(send);
await _pushService.PushSyncSendCreateAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendCreated);
}
else
{
send.RevisionDate = DateTime.UtcNow;
await _sendRepository.UpsertAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
}
}
public async Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Send is not of type \"file\".");
}
if (fileLength < 1)
{
throw new BadRequestException("No file data.");
}
var storageBytesRemaining = await StorageRemainingForSendAsync(send);
if (storageBytesRemaining < fileLength)
{
throw new BadRequestException("Not enough storage available.");
}
var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false);
try
{
data.Id = fileId;
data.Size = fileLength;
data.Validated = false;
send.Data = JsonSerializer.Serialize(data,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId);
}
catch
{
// Clean up since this is not transactional
await _sendFileStorageService.DeleteFileAsync(send, fileId);
throw;
}
}
public async Task UploadFileToExistingSendAsync(Stream stream, Send send)
{
if (send?.Data == null)
{
throw new BadRequestException("Send does not have file data");
}
if (send.Type != SendType.File)
{
throw new BadRequestException("Not a File Type Send.");
}
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
if (data.Validated)
{
throw new BadRequestException("File has already been uploaded.");
}
await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id);
if (!await ValidateSendFile(send))
{
throw new BadRequestException("File received does not match expected file length.");
}
}
public async Task<bool> ValidateSendFile(Send send)
{
var fileData = JsonSerializer.Deserialize<SendFileData>(send.Data);
var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway);
if (!valid || realSize > MAX_FILE_SIZE)
{
// File reported differs in size from that promised. Must be a rogue client. Delete Send
await DeleteSendAsync(send);
return false;
}
// Update Send data if necessary
if (realSize != fileData.Size)
{
fileData.Size = realSize.Value;
}
fileData.Validated = true;
send.Data = JsonSerializer.Serialize(fileData,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return valid;
}
public async Task DeleteSendAsync(Send send)
{
await _sendRepository.DeleteAsync(send);
if (send.Type == Enums.SendType.File)
{
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
await _sendFileStorageService.DeleteFileAsync(send, data.Id);
}
await _pushService.PushSyncSendDeleteAsync(send);
}
public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send,
string password)
{
var now = DateTime.UtcNow;
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
send.DeletionDate < now)
{
return (false, false, false);
}
if (!string.IsNullOrWhiteSpace(send.Password))
{
if (string.IsNullOrWhiteSpace(password))
{
return (false, true, false);
}
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
{
send.Password = HashPassword(password);
}
if (passwordResult == PasswordVerificationResult.Failed)
{
return (false, false, true);
}
}
return (true, false, false);
}
// Response: Send, password required, password invalid
public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Can only get a download URL for a file type of Send");
}
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
send.AccessCount++;
await _sendRepository.ReplaceAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false);
}
// Response: Send, password required, password invalid
public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password)
{
var send = await _sendRepository.GetByIdAsync(sendId);
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
// TODO: maybe move this to a simple ++ sproc?
if (send.Type != SendType.File)
{
// File sends are incremented during file download
send.AccessCount++;
}
await _sendRepository.ReplaceAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed);
return (send, false, false);
}
private async Task RaiseReferenceEventAsync(Send send, ReferenceEventType eventType)
{
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
{
Id = send.UserId ?? default,
Type = eventType,
Source = ReferenceEventSource.User,
SendType = send.Type,
MaxAccessCount = send.MaxAccessCount,
HasPassword = !string.IsNullOrWhiteSpace(send.Password),
SendHasNotes = send.Data?.Contains("Notes"),
ClientId = _currentContext.ClientId,
ClientVersion = _currentContext.ClientVersion
});
}
public string HashPassword(string password)
{
return _passwordHasher.HashPassword(new User(), password);
}
private async Task ValidateUserCanSaveAsync(Guid? userId, Send send)
{
if (_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements))
{
await ValidateUserCanSaveAsync_vNext(userId, send);
return;
}
if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true))
{
return;
}
var anyDisableSendPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(userId.Value,
PolicyType.DisableSend);
if (anyDisableSendPolicies)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
if (send.HideEmail.GetValueOrDefault())
{
var sendOptionsPolicies = await _policyService.GetPoliciesApplicableToUserAsync(userId.Value, PolicyType.SendOptions);
if (sendOptionsPolicies.Any(p => CoreHelpers.LoadClassFromJsonData<SendOptionsPolicyData>(p.PolicyData)?.DisableHideEmail ?? false))
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
}
private async Task ValidateUserCanSaveAsync_vNext(Guid? userId, Send send)
{
if (!userId.HasValue)
{
return;
}
var disableSendRequirement = await _policyRequirementQuery.GetAsync<DisableSendPolicyRequirement>(userId.Value);
if (disableSendRequirement.DisableSend)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
var sendOptionsRequirement = await _policyRequirementQuery.GetAsync<SendOptionsPolicyRequirement>(userId.Value);
if (sendOptionsRequirement.DisableHideEmail && send.HideEmail.GetValueOrDefault())
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
private async Task<long> StorageRemainingForSendAsync(Send send)
{
var storageBytesRemaining = 0L;
if (send.UserId.HasValue)
{
var user = await _userRepository.GetByIdAsync(send.UserId.Value);
if (!await _userService.CanAccessPremium(user))
{
throw new BadRequestException("You must have premium status to use file Sends.");
}
if (!user.EmailVerified)
{
throw new BadRequestException("You must confirm your email to use file Sends.");
}
if (user.Premium)
{
storageBytesRemaining = user.StorageBytesRemaining();
}
else
{
// Users that get access to file storage/premium from their organization get the default
// 1 GB max storage.
storageBytesRemaining = user.StorageBytesRemaining(
_globalSettings.SelfHosted ? (short)10240 : (short)1);
}
}
else if (send.OrganizationId.HasValue)
{
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
if (!org.MaxStorageGb.HasValue)
{
throw new BadRequestException("This organization cannot use file sends.");
}
storageBytesRemaining = org.StorageBytesRemaining();
}
return storageBytesRemaining;
}
}

View File

@ -43,6 +43,7 @@ using Bit.Core.Settings;
using Bit.Core.Tokens;
using Bit.Core.Tools.ImportFeatures;
using Bit.Core.Tools.ReportFeatures;
using Bit.Core.Tools.SendFeatures;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
using Bit.Core.Vault;
@ -123,7 +124,7 @@ public static class ServiceCollectionExtensions
services.AddScoped<ISsoConfigService, SsoConfigService>();
services.AddScoped<IAuthRequestService, AuthRequestService>();
services.AddScoped<IDuoUniversalTokenService, DuoUniversalTokenService>();
services.AddScoped<ISendService, SendService>();
services.AddScoped<ISendAuthorizationService, SendAuthorizationService>();
services.AddLoginServices();
services.AddScoped<IOrganizationDomainService, OrganizationDomainService>();
services.AddVaultServices();
@ -132,6 +133,7 @@ public static class ServiceCollectionExtensions
services.AddNotificationCenterServices();
services.AddPlatformServices();
services.AddImportServices();
services.AddSendServices();
}
public static void AddTokenizers(this IServiceCollection services)

View File

@ -23,11 +23,11 @@ public class SendRotationValidatorTests
public async Task ValidateAsync_Success()
{
// Arrange
var sendService = Substitute.For<ISendService>();
var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
var sendRepository = Substitute.For<ISendRepository>();
var sut = new SendRotationValidator(
sendService,
sendAuthorizationService,
sendRepository
);
@ -52,11 +52,11 @@ public class SendRotationValidatorTests
public async Task ValidateAsync_SendNotReturnedFromRepository_NotIncludedInOutput()
{
// Arrange
var sendService = Substitute.For<ISendService>();
var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
var sendRepository = Substitute.For<ISendRepository>();
var sut = new SendRotationValidator(
sendService,
sendAuthorizationService,
sendRepository
);
@ -76,11 +76,11 @@ public class SendRotationValidatorTests
public async Task ValidateAsync_InputMissingUserSend_Throws()
{
// Arrange
var sendService = Substitute.For<ISendService>();
var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
var sendRepository = Substitute.For<ISendRepository>();
var sut = new SendRotationValidator(
sendService,
sendAuthorizationService,
sendRepository
);

View File

@ -10,7 +10,9 @@ using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Commands.Interfaces;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Mvc;
@ -26,7 +28,9 @@ public class SendsControllerTests : IDisposable
private readonly GlobalSettings _globalSettings;
private readonly IUserService _userService;
private readonly ISendRepository _sendRepository;
private readonly ISendService _sendService;
private readonly INonAnonymousSendCommand _nonAnonymousSendCommand;
private readonly IAnonymousSendCommand _anonymousSendCommand;
private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly ILogger<SendsController> _logger;
private readonly ICurrentContext _currentContext;
@ -35,7 +39,9 @@ public class SendsControllerTests : IDisposable
{
_userService = Substitute.For<IUserService>();
_sendRepository = Substitute.For<ISendRepository>();
_sendService = Substitute.For<ISendService>();
_nonAnonymousSendCommand = Substitute.For<INonAnonymousSendCommand>();
_anonymousSendCommand = Substitute.For<IAnonymousSendCommand>();
_sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
_sendFileStorageService = Substitute.For<ISendFileStorageService>();
_globalSettings = new GlobalSettings();
_logger = Substitute.For<ILogger<SendsController>>();
@ -44,7 +50,9 @@ public class SendsControllerTests : IDisposable
_sut = new SendsController(
_sendRepository,
_userService,
_sendService,
_sendAuthorizationService,
_anonymousSendCommand,
_nonAnonymousSendCommand,
_sendFileStorageService,
_logger,
_globalSettings,
@ -68,7 +76,8 @@ public class SendsControllerTests : IDisposable
send.Data = JsonSerializer.Serialize(new Dictionary<string, string>());
send.HideEmail = true;
_sendService.AccessAsync(id, null).Returns((send, false, false));
_sendRepository.GetByIdAsync(Arg.Any<Guid>()).Returns(send);
_sendAuthorizationService.AccessAsync(send, null).Returns(SendAccessResult.Granted);
_userService.GetUserByIdAsync(Arg.Any<Guid>()).Returns(user);
var request = new SendAccessRequestModel();

View File

@ -34,11 +34,11 @@ public class SendRequestModelTests
Type = SendType.Text,
};
var sendService = Substitute.For<ISendService>();
sendService.HashPassword(Arg.Any<string>())
var sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
sendAuthorizationService.HashPassword(Arg.Any<string>())
.Returns((info) => $"hashed_{(string)info[0]}");
var send = sendRequest.ToSend(Guid.NewGuid(), sendService);
var send = sendRequest.ToSend(Guid.NewGuid(), sendAuthorizationService);
Assert.Equal(deletionDate, send.DeletionDate);
Assert.False(send.Disabled);

File diff suppressed because it is too large Load Diff

View File

@ -3,14 +3,11 @@ using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Caches;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Services.Implementations;
using Bit.Core.Billing.Tax.Models;
using Bit.Core.Billing.Tax.Services;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Test.Billing.Tax.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Braintree;
@ -195,7 +192,7 @@ public class SubscriberServiceTests
await stripeAdapter
.DidNotReceiveWithAnyArgs()
.SubscriptionCancelAsync(Arg.Any<string>(), Arg.Any<SubscriptionCancelOptions>()); ;
.SubscriptionCancelAsync(Arg.Any<string>(), Arg.Any<SubscriptionCancelOptions>());
}
#endregion
@ -1029,7 +1026,7 @@ public class SubscriberServiceTests
stripeAdapter
.PaymentMethodListAutoPagingAsync(Arg.Any<PaymentMethodListOptions>())
.Returns(GetPaymentMethodsAsync(new List<Stripe.PaymentMethod>()));
.Returns(GetPaymentMethodsAsync(new List<PaymentMethod>()));
await sutProvider.Sut.RemovePaymentSource(organization);
@ -1061,7 +1058,7 @@ public class SubscriberServiceTests
stripeAdapter
.PaymentMethodListAutoPagingAsync(Arg.Any<PaymentMethodListOptions>())
.Returns(GetPaymentMethodsAsync(new List<Stripe.PaymentMethod>
.Returns(GetPaymentMethodsAsync(new List<PaymentMethod>
{
new ()
{
@ -1086,8 +1083,8 @@ public class SubscriberServiceTests
.PaymentMethodDetachAsync(cardId);
}
private static async IAsyncEnumerable<Stripe.PaymentMethod> GetPaymentMethodsAsync(
IEnumerable<Stripe.PaymentMethod> paymentMethods)
private static async IAsyncEnumerable<PaymentMethod> GetPaymentMethodsAsync(
IEnumerable<PaymentMethod> paymentMethods)
{
foreach (var paymentMethod in paymentMethods)
{
@ -1598,14 +1595,22 @@ public class SubscriberServiceTests
City = "Example Town",
State = "NY"
},
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] }
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] },
Subscriptions = new StripeList<Subscription>
{
Data = [
new Subscription
{
Id = provider.GatewaySubscriptionId,
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
}
]
}
});
var subscription = new Subscription { Items = new StripeList<SubscriptionItem>() };
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(Arg.Any<string>())
.Returns(subscription);
sutProvider.GetDependency<IAutomaticTaxFactory>().CreateAsync(Arg.Any<AutomaticTaxFactoryParameters>())
.Returns(new FakeAutomaticTaxStrategy(true));
await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation);
@ -1623,6 +1628,98 @@ public class SubscriberServiceTests
await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is<TaxIdCreateOptions>(
options => options.Type == "us_ein" &&
options.Value == taxInformation.TaxId));
await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
}
[Theory, BitAutoData]
public async Task UpdateTaxInformation_NonUser_ReverseCharge_MakesCorrectInvocations(
Provider provider,
SutProvider<SubscriberService> sutProvider)
{
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } };
stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is<CustomerGetOptions>(
options => options.Expand.Contains("tax_ids"))).Returns(customer);
var taxInformation = new TaxInformation(
"CA",
"12345",
"123456789",
"us_ein",
"123 Example St.",
null,
"Example Town",
"NY");
sutProvider.GetDependency<IStripeAdapter>()
.CustomerUpdateAsync(
Arg.Is<string>(p => p == provider.GatewayCustomerId),
Arg.Is<CustomerUpdateOptions>(options =>
options.Address.Country == "CA" &&
options.Address.PostalCode == "12345" &&
options.Address.Line1 == "123 Example St." &&
options.Address.Line2 == null &&
options.Address.City == "Example Town" &&
options.Address.State == "NY"))
.Returns(new Customer
{
Id = provider.GatewayCustomerId,
Address = new Address
{
Country = "CA",
PostalCode = "12345",
Line1 = "123 Example St.",
Line2 = null,
City = "Example Town",
State = "NY"
},
TaxIds = new StripeList<TaxId> { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] },
Subscriptions = new StripeList<Subscription>
{
Data = [
new Subscription
{
Id = provider.GatewaySubscriptionId,
CustomerId = provider.GatewayCustomerId,
AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }
}
]
}
});
var subscription = new Subscription { Items = new StripeList<SubscriptionItem>() };
sutProvider.GetDependency<IStripeAdapter>().SubscriptionGetAsync(Arg.Any<string>())
.Returns(subscription);
sutProvider.GetDependency<IFeatureService>()
.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true);
await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation);
await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is<CustomerUpdateOptions>(
options =>
options.Address.Country == taxInformation.Country &&
options.Address.PostalCode == taxInformation.PostalCode &&
options.Address.Line1 == taxInformation.Line1 &&
options.Address.Line2 == taxInformation.Line2 &&
options.Address.City == taxInformation.City &&
options.Address.State == taxInformation.State));
await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1");
await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is<TaxIdCreateOptions>(
options => options.Type == "us_ein" &&
options.Value == taxInformation.TaxId));
await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId,
Arg.Is<CustomerUpdateOptions>(options => options.TaxExempt == StripeConstants.TaxExempt.Reverse));
await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
Arg.Is<SubscriptionUpdateOptions>(options => options.AutomaticTax.Enabled == true));
}
#endregion

View File

@ -0,0 +1,118 @@
using System.Text.Json;
using Bit.Core.Exceptions;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.SendFeatures.Commands;
using Bit.Core.Tools.Services;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Tools.Services;
public class AnonymousSendCommandTests
{
private readonly ISendRepository _sendRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPushNotificationService _pushNotificationService;
private readonly ISendAuthorizationService _sendAuthorizationService;
private readonly AnonymousSendCommand _anonymousSendCommand;
public AnonymousSendCommandTests()
{
_sendRepository = Substitute.For<ISendRepository>();
_sendFileStorageService = Substitute.For<ISendFileStorageService>();
_pushNotificationService = Substitute.For<IPushNotificationService>();
_sendAuthorizationService = Substitute.For<ISendAuthorizationService>();
_anonymousSendCommand = new AnonymousSendCommand(
_sendRepository,
_sendFileStorageService,
_pushNotificationService,
_sendAuthorizationService);
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_Success_ReturnsDownloadUrl()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.File,
AccessCount = 0,
Data = JsonSerializer.Serialize(new { Id = "fileId123" })
};
var fileId = "fileId123";
var password = "testPassword";
var expectedUrl = "https://example.com/download";
_sendAuthorizationService
.SendCanBeAccessed(send, password)
.Returns(SendAccessResult.Granted);
_sendFileStorageService
.GetSendFileDownloadUrlAsync(send, fileId)
.Returns(expectedUrl);
// Act
var result =
await _anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId, password);
// Assert
Assert.Equal(expectedUrl, result.Item1);
Assert.Equal(1, send.AccessCount);
await _sendRepository.Received(1).ReplaceAsync(send);
await _pushNotificationService.Received(1).PushSyncSendUpdateAsync(send);
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_AccessDenied_ReturnsNullWithReasons()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.File,
AccessCount = 0
};
var fileId = "fileId123";
var password = "wrongPassword";
_sendAuthorizationService
.SendCanBeAccessed(send, password)
.Returns(SendAccessResult.Denied);
// Act
var result =
await _anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId, password);
// Assert
Assert.Null(result.Item1);
Assert.Equal(SendAccessResult.Denied, result.Item2);
Assert.Equal(0, send.AccessCount);
await _sendRepository.DidNotReceiveWithAnyArgs().ReplaceAsync(default);
await _pushNotificationService.DidNotReceiveWithAnyArgs().PushSyncSendUpdateAsync(default);
}
[Fact]
public async Task GetSendFileDownloadUrlAsync_NotFileSend_ThrowsBadRequestException()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
Type = SendType.Text
};
var fileId = "fileId123";
var password = "testPassword";
// Act & Assert
await Assert.ThrowsAsync<BadRequestException>(() =>
_anonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId, password));
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,175 @@
using Bit.Core.Context;
using Bit.Core.Platform.Push;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.Services;
using Microsoft.AspNetCore.Identity;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Tools.Services;
public class SendAuthorizationServiceTests
{
private readonly ISendRepository _sendRepository;
private readonly IPasswordHasher<Bit.Core.Entities.User> _passwordHasher;
private readonly IPushNotificationService _pushNotificationService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
private readonly SendAuthorizationService _sendAuthorizationService;
public SendAuthorizationServiceTests()
{
_sendRepository = Substitute.For<ISendRepository>();
_passwordHasher = Substitute.For<IPasswordHasher<Bit.Core.Entities.User>>();
_pushNotificationService = Substitute.For<IPushNotificationService>();
_referenceEventService = Substitute.For<IReferenceEventService>();
_currentContext = Substitute.For<ICurrentContext>();
_sendAuthorizationService = new SendAuthorizationService(
_sendRepository,
_passwordHasher,
_pushNotificationService,
_referenceEventService,
_currentContext);
}
[Fact]
public void SendCanBeAccessed_Success_ReturnsTrue()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
MaxAccessCount = 10,
AccessCount = 5,
ExpirationDate = DateTime.UtcNow.AddYears(1),
DeletionDate = DateTime.UtcNow.AddYears(1),
Disabled = false,
Password = "hashedPassword123"
};
const string password = "TEST";
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), send.Password, password)
.Returns(PasswordVerificationResult.Success);
// Act
var result =
_sendAuthorizationService.SendCanBeAccessed(send, password);
// Assert
Assert.Equal(SendAccessResult.Granted, result);
}
[Fact]
public void SendCanBeAccessed_NullMaxAccess_Success()
{
// Arrange
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
MaxAccessCount = null,
AccessCount = 5,
ExpirationDate = DateTime.UtcNow.AddYears(1),
DeletionDate = DateTime.UtcNow.AddYears(1),
Disabled = false,
Password = "hashedPassword123"
};
const string password = "TEST";
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), send.Password, password)
.Returns(PasswordVerificationResult.Success);
// Act
var result = _sendAuthorizationService.SendCanBeAccessed(send, password);
// Assert
Assert.Equal(SendAccessResult.Granted, result);
}
[Fact]
public void SendCanBeAccessed_NullSend_DoesNotGrantAccess()
{
// Arrange
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Success);
// Act
var result =
_sendAuthorizationService.SendCanBeAccessed(null, "TEST");
// Assert
Assert.Equal(SendAccessResult.Denied, result);
}
[Fact]
public void SendCanBeAccessed_RehashNeeded_RehashesPassword()
{
// Arrange
var now = DateTime.UtcNow;
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
MaxAccessCount = null,
AccessCount = 5,
ExpirationDate = now.AddYears(1),
DeletionDate = now.AddYears(1),
Disabled = false,
Password = "TEST"
};
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.SuccessRehashNeeded);
// Act
var result =
_sendAuthorizationService.SendCanBeAccessed(send, "TEST");
// Assert
_passwordHasher
.Received(1)
.HashPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST");
Assert.Equal(SendAccessResult.Granted, result);
}
[Fact]
public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue()
{
// Arrange
var now = DateTime.UtcNow;
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
MaxAccessCount = null,
AccessCount = 5,
ExpirationDate = now.AddYears(1),
DeletionDate = now.AddYears(1),
Disabled = false,
Password = "TEST"
};
_passwordHasher
.VerifyHashedPassword(Arg.Any<Bit.Core.Entities.User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Failed);
// Act
var result =
_sendAuthorizationService.SendCanBeAccessed(send, "TEST");
// Assert
Assert.Equal(SendAccessResult.PasswordInvalid, result);
}
}

View File

@ -1,867 +0,0 @@
using System.Text;
using System.Text.Json;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Test.AutoFixture.CurrentContextFixtures;
using Bit.Core.Test.Entities;
using Bit.Core.Test.Tools.AutoFixture.SendFixtures;
using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Data;
using Bit.Core.Tools.Repositories;
using Bit.Core.Tools.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Identity;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using Xunit;
using GlobalSettings = Bit.Core.Settings.GlobalSettings;
namespace Bit.Core.Test.Tools.Services;
[SutProviderCustomize]
[CurrentContextCustomize]
[UserSendCustomize]
public class SendServiceTests
{
private void SaveSendAsync_Setup(SendType sendType, bool disableSendPolicyAppliesToUser,
SutProvider<SendService> sutProvider, Send send)
{
send.Id = default;
send.Type = sendType;
sutProvider.GetDependency<IPolicyService>().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.DisableSend).Returns(disableSendPolicyAppliesToUser);
}
// Disable Send policy check
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableSend_Applies_throws(SendType sendType,
SutProvider<SendService> sutProvider, Send send)
{
SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: true, sutProvider, send);
await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableSend_DoesntApply_success(SendType sendType,
SutProvider<SendService> sutProvider, Send send)
{
SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: false, sutProvider, send);
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
// Send Options Policy - Disable Hide Email check
private void SaveSendAsync_HideEmail_Setup(bool disableHideEmailAppliesToUser,
SutProvider<SendService> sutProvider, Send send, Policy policy)
{
send.HideEmail = true;
var sendOptions = new SendOptionsPolicyData
{
DisableHideEmail = disableHideEmailAppliesToUser
};
policy.Data = JsonSerializer.Serialize(sendOptions, new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
});
sutProvider.GetDependency<IPolicyService>().GetPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), PolicyType.SendOptions).Returns(new List<OrganizationUserPolicyDetails>()
{
new() { PolicyType = policy.Type, PolicyData = policy.Data, OrganizationId = policy.OrganizationId, PolicyEnabled = policy.Enabled }
});
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_Applies_throws(SendType sendType,
SutProvider<SendService> sutProvider, Send send, Policy policy)
{
SaveSendAsync_Setup(sendType, false, sutProvider, send);
SaveSendAsync_HideEmail_Setup(true, sutProvider, send, policy);
await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_DoesntApply_success(SendType sendType,
SutProvider<SendService> sutProvider, Send send, Policy policy)
{
SaveSendAsync_Setup(sendType, false, sutProvider, send);
SaveSendAsync_HideEmail_Setup(false, sutProvider, send, policy);
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
// Disable Send policy check - vNext
private void SaveSendAsync_Setup_vNext(SutProvider<SendService> sutProvider, Send send,
DisableSendPolicyRequirement disableSendPolicyRequirement, SendOptionsPolicyRequirement sendOptionsPolicyRequirement)
{
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<DisableSendPolicyRequirement>(send.UserId!.Value)
.Returns(disableSendPolicyRequirement);
sutProvider.GetDependency<IPolicyRequirementQuery>().GetAsync<SendOptionsPolicyRequirement>(send.UserId!.Value)
.Returns(sendOptionsPolicyRequirement);
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
// Should not be called in these tests
sutProvider.GetDependency<IPolicyService>().AnyPoliciesApplicableToUserAsync(
Arg.Any<Guid>(), Arg.Any<PolicyType>()).ThrowsAsync<Exception>();
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableSend_Applies_Throws_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement { DisableSend = true }, new SendOptionsPolicyRequirement());
var exception = await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
Assert.Contains("Due to an Enterprise Policy, you are only able to delete an existing Send.",
exception.Message);
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableSend_DoesntApply_Success_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement());
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
// Send Options Policy - Disable Hide Email check
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_Applies_Throws_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement { DisableHideEmail = true });
send.HideEmail = true;
var exception = await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.SaveSendAsync(send));
Assert.Contains("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.", exception.Message);
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_Applies_ButEmailNotHidden_Success_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement { DisableHideEmail = true });
send.HideEmail = false;
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
[Theory]
[BitAutoData(SendType.File)]
[BitAutoData(SendType.Text)]
public async Task SaveSendAsync_DisableHideEmail_DoesntApply_Success_vNext(SendType sendType,
SutProvider<SendService> sutProvider, [NewUserSendCustomize] Send send)
{
send.Type = sendType;
SaveSendAsync_Setup_vNext(sutProvider, send, new DisableSendPolicyRequirement(), new SendOptionsPolicyRequirement());
send.HideEmail = true;
await sutProvider.Sut.SaveSendAsync(send);
await sutProvider.GetDependency<ISendRepository>().Received(1).CreateAsync(send);
}
[Theory]
[BitAutoData]
public async Task SaveSendAsync_ExistingSend_Updates(SutProvider<SendService> sutProvider,
Send send)
{
send.Id = Guid.NewGuid();
var now = DateTime.UtcNow;
await sutProvider.Sut.SaveSendAsync(send);
Assert.True(send.RevisionDate - now < TimeSpan.FromSeconds(1));
await sutProvider.GetDependency<ISendRepository>()
.Received(1)
.UpsertAsync(send);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushSyncSendUpdateAsync(send);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_TextType_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
send.Type = SendType.Text;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 0)
);
Assert.Contains("not of type \"file\"", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_EmptyFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
send.Type = SendType.File;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 0)
);
Assert.Contains("no file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCannotAccessPremium_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(false);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("must have premium", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserHasUnconfirmedEmail_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = false,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("must confirm your email", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCanAccessPremium_HasNoStorage_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
Premium = true,
MaxStorageGb = null,
Storage = 0,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCanAccessPremium_StorageFull_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
Premium = true,
MaxStorageGb = 2,
Storage = 2 * UserTests.Multiplier,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsSelfHosted_GiantFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
Premium = false,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<GlobalSettings>()
.SelfHosted = true;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 11000 * UserTests.Multiplier)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsNotSelfHosted_TwoGigabyteFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
Premium = false,
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<GlobalSettings>()
.SelfHosted = false;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var org = new Organization
{
Id = Guid.NewGuid(),
MaxStorageGb = null,
};
send.UserId = null;
send.OrganizationId = org.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(org.Id)
.Returns(org);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_TwoGBFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var org = new Organization
{
Id = Guid.NewGuid(),
MaxStorageGb = null,
};
send.UserId = null;
send.OrganizationId = org.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(org.Id)
.Returns(org);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 1)
);
Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_ThroughOrg_MaxStorageIsOneGB_TwoGBFile_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var org = new Organization
{
Id = Guid.NewGuid(),
MaxStorageGb = 1,
};
send.UserId = null;
send.OrganizationId = org.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(org.Id)
.Returns(org);
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier)
);
Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_HasEnoughStorage_Success(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
MaxStorageGb = 10,
};
var data = new SendFileData
{
};
send.UserId = user.Id;
send.Type = SendType.File;
var testUrl = "https://test.com/";
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<ISendFileStorageService>()
.GetSendFileUploadUrlAsync(send, Arg.Any<string>())
.Returns(testUrl);
var utcNow = DateTime.UtcNow;
var url = await sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier);
Assert.Equal(testUrl, url);
Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1));
await sutProvider.GetDependency<ISendFileStorageService>()
.Received(1)
.GetSendFileUploadUrlAsync(send, Arg.Any<string>());
await sutProvider.GetDependency<ISendRepository>()
.Received(1)
.UpsertAsync(send);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushSyncSendUpdateAsync(send);
}
[Theory]
[BitAutoData]
public async Task SaveFileSendAsync_HasEnoughStorage_SendFileThrows_CleansUp(SutProvider<SendService> sutProvider,
Send send)
{
var user = new User
{
Id = Guid.NewGuid(),
EmailVerified = true,
MaxStorageGb = 10,
};
var data = new SendFileData
{
};
send.UserId = user.Id;
send.Type = SendType.File;
sutProvider.GetDependency<IUserRepository>()
.GetByIdAsync(user.Id)
.Returns(user);
sutProvider.GetDependency<IUserService>()
.CanAccessPremium(user)
.Returns(true);
sutProvider.GetDependency<ISendFileStorageService>()
.GetSendFileUploadUrlAsync(send, Arg.Any<string>())
.Returns<string>(callInfo => throw new Exception("Problem"));
var utcNow = DateTime.UtcNow;
var exception = await Assert.ThrowsAsync<Exception>(() =>
sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier)
);
Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1));
Assert.Equal("Problem", exception.Message);
await sutProvider.GetDependency<ISendFileStorageService>()
.Received(1)
.GetSendFileUploadUrlAsync(send, Arg.Any<string>());
await sutProvider.GetDependency<ISendRepository>()
.Received(1)
.UpsertAsync(send);
await sutProvider.GetDependency<IPushNotificationService>()
.Received(1)
.PushSyncSendUpdateAsync(send);
await sutProvider.GetDependency<ISendFileStorageService>()
.Received(1)
.DeleteFileAsync(send, Arg.Any<string>());
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_SendNull_ThrowsBadRequest(SutProvider<SendService> sutProvider)
{
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), null)
);
Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_SendDataNull_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
send.Data = null;
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send)
);
Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_NotFileType_ThrowsBadRequest(SutProvider<SendService> sutProvider,
Send send)
{
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send)
);
Assert.Contains("not a file type send", badRequest.Message, StringComparison.InvariantCultureIgnoreCase);
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_Success(SutProvider<SendService> sutProvider,
Send send)
{
var fileContents = "Test file content";
var sendFileData = new SendFileData
{
Id = "TEST",
Size = fileContents.Length,
Validated = false,
};
send.Type = SendType.File;
send.Data = JsonSerializer.Serialize(sendFileData);
sutProvider.GetDependency<ISendFileStorageService>()
.ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any<long>())
.Returns((true, sendFileData.Size));
await sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send);
}
[Theory]
[BitAutoData]
public async Task UpdateFileToExistingSendAsync_InvalidSize(SutProvider<SendService> sutProvider,
Send send)
{
var fileContents = "Test file content";
var sendFileData = new SendFileData
{
Id = "TEST",
Size = fileContents.Length,
};
send.Type = SendType.File;
send.Data = JsonSerializer.Serialize(sendFileData);
sutProvider.GetDependency<ISendFileStorageService>()
.ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any<long>())
.Returns((false, sendFileData.Size));
var badRequest = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send)
);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_Success(SutProvider<SendService> sutProvider, Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = 10;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), send.Password, "TEST")
.Returns(PasswordVerificationResult.Success);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
Assert.True(grant);
Assert.False(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_NullMaxAccess_Success(SutProvider<SendService> sutProvider,
Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = null;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), send.Password, "TEST")
.Returns(PasswordVerificationResult.Success);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
Assert.True(grant);
Assert.False(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_NullSend_DoesNotGrantAccess(SutProvider<SendService> sutProvider)
{
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Success);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(null, "TEST");
Assert.False(grant);
Assert.False(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_NullPassword_PasswordRequiredErrorReturnsTrue(SutProvider<SendService> sutProvider,
Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = null;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
send.Password = "HASH";
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Success);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, null);
Assert.False(grant);
Assert.True(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_RehashNeeded_RehashesPassword(SutProvider<SendService> sutProvider,
Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = null;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
send.Password = "TEST";
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.SuccessRehashNeeded);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
sutProvider.GetDependency<IPasswordHasher<User>>()
.Received(1)
.HashPassword(Arg.Any<User>(), "TEST");
Assert.True(grant);
Assert.False(passwordRequiredError);
Assert.False(passwordInvalidError);
}
[Theory]
[BitAutoData]
public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue(SutProvider<SendService> sutProvider,
Send send)
{
var now = DateTime.UtcNow;
send.MaxAccessCount = null;
send.AccessCount = 5;
send.ExpirationDate = now.AddYears(1);
send.DeletionDate = now.AddYears(1);
send.Disabled = false;
send.Password = "TEST";
sutProvider.GetDependency<IPasswordHasher<User>>()
.VerifyHashedPassword(Arg.Any<User>(), "TEST", "TEST")
.Returns(PasswordVerificationResult.Failed);
var (grant, passwordRequiredError, passwordInvalidError)
= sutProvider.Sut.SendCanBeAccessed(send, "TEST");
Assert.False(grant);
Assert.False(passwordRequiredError);
Assert.True(passwordInvalidError);
}
}