diff --git a/Directory.Build.props b/Directory.Build.props index 60d61e5e26..ac814ef8d8 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2025.5.0 + 2025.5.1 Bit.$(MSBuildProjectName) enable diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index 22a2e93642..35a00f4253 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -8,13 +8,10 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Services; -using Bit.Core.Billing.Tax.Services.Implementations; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; -using Microsoft.Extensions.DependencyInjection; using Stripe; namespace Bit.Commercial.Core.AdminConsole.Providers; @@ -24,7 +21,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv private readonly IEventService _eventService; private readonly IMailService _mailService; private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; private readonly IProviderOrganizationRepository _providerOrganizationRepository; private readonly IStripeAdapter _stripeAdapter; private readonly IFeatureService _featureService; @@ -32,26 +28,22 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv private readonly ISubscriberService _subscriberService; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; private readonly IPricingClient _pricingClient; - private readonly IAutomaticTaxStrategy _automaticTaxStrategy; public RemoveOrganizationFromProviderCommand( IEventService eventService, IMailService mailService, IOrganizationRepository organizationRepository, - IOrganizationService organizationService, IProviderOrganizationRepository providerOrganizationRepository, IStripeAdapter stripeAdapter, IFeatureService featureService, IProviderBillingService providerBillingService, ISubscriberService subscriberService, IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, - IPricingClient pricingClient, - [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) + IPricingClient pricingClient) { _eventService = eventService; _mailService = mailService; _organizationRepository = organizationRepository; - _organizationService = organizationService; _providerOrganizationRepository = providerOrganizationRepository; _stripeAdapter = stripeAdapter; _featureService = featureService; @@ -59,7 +51,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv _subscriberService = subscriberService; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; _pricingClient = pricingClient; - _automaticTaxStrategy = automaticTaxStrategy; } public async Task RemoveOrganizationFromProvider( @@ -77,7 +68,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync( providerOrganization.OrganizationId, - Array.Empty(), + [], includeProvider: false)) { throw new BadRequestException("Organization must have at least one confirmed owner."); @@ -102,7 +93,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv /// /// When a client organization is unlinked from a provider, we have to check if they're Stripe-enabled /// and, if they are, we remove their MSP discount and set their Subscription to `send_invoice`. This is because - /// the provider's payment method will be removed from their Stripe customer causing ensuing charges to fail. Lastly, + /// the provider's payment method will be removed from their Stripe customer, causing ensuing charges to fail. Lastly, /// we email the organization owners letting them know they need to add a new payment method. /// private async Task ResetOrganizationBillingAsync( @@ -142,15 +133,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }] }; - if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + var setNonUSBusinessUseToReverseCharge = _featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge) { - _automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - else + else if (customer.HasRecognizedTaxLocation()) { - subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { - Enabled = true + Enabled = customer.Address.Country == "US" || + customer.TaxIds.Any() }; } @@ -187,7 +181,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv await _mailService.SendProviderUpdatePaymentMethod( organization.Id, organization.Name, - provider.Name, + provider.Name!, organizationOwnerEmails); } } diff --git a/bitwarden_license/src/Commercial.Core/Billing/BusinessUnitConverter.cs b/bitwarden_license/src/Commercial.Core/Billing/BusinessUnitConverter.cs index 97d9377cd6..d27b45af4a 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/BusinessUnitConverter.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/BusinessUnitConverter.cs @@ -67,6 +67,7 @@ public class BusinessUnitConverter( organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb; organization.UsePolicies = updatedPlan.HasPolicies; organization.UseSso = updatedPlan.HasSso; + organization.UseOrganizationDomains = updatedPlan.HasOrganizationDomains; organization.UseGroups = updatedPlan.HasGroups; organization.UseEvents = updatedPlan.HasEvents; organization.UseDirectory = updatedPlan.HasDirectory; diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index b1eefbffe3..c8d6505183 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -18,7 +18,6 @@ using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Services; -using Bit.Core.Billing.Tax.Services.Implementations; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -27,7 +26,6 @@ using Bit.Core.Services; using Bit.Core.Settings; using Braintree; using CsvHelper; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; @@ -52,8 +50,7 @@ public class ProviderBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - ITaxService taxService, - [FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy) + ITaxService taxService) : IProviderBillingService { public async Task AddExistingOrganization( @@ -99,6 +96,7 @@ public class ProviderBillingService( organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; organization.UsePolicies = plan.HasPolicies; organization.UseSso = plan.HasSso; + organization.UseOrganizationDomains = plan.HasOrganizationDomains; organization.UseGroups = plan.HasGroups; organization.UseEvents = plan.HasEvents; organization.UseDirectory = plan.HasDirectory; @@ -127,7 +125,7 @@ public class ProviderBillingService( /* * We have to scale the provider's seats before the ProviderOrganization - * row is inserted so the added organization's seats don't get double counted. + * row is inserted so the added organization's seats don't get double-counted. */ await ScaleSeats(provider, organization.PlanType, organization.Seats!.Value); @@ -235,7 +233,7 @@ public class ProviderBillingService( var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions { - Expand = ["tax_ids"] + Expand = ["tax", "tax_ids"] }); var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); @@ -283,6 +281,13 @@ public class ProviderBillingService( ] }; + var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge && providerCustomer.Address is not { Country: "US" }) + { + customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; + } + var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); organization.GatewayCustomerId = customer.Id; @@ -519,6 +524,13 @@ public class ProviderBillingService( } }; + var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge && taxInfo.BillingAddressCountry != "US") + { + options.TaxExempt = StripeConstants.TaxExempt.Reverse; + } + if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber)) { var taxIdType = taxService.GetStripeTaxCode( @@ -530,6 +542,7 @@ public class ProviderBillingService( logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber); + throw new BadRequestException("billingTaxIdTypeInferenceError"); } @@ -717,14 +730,21 @@ public class ProviderBillingService( TrialPeriodDays = trialPeriodDays }; - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) - { - automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); - } - else + var setNonUSBusinessUseToReverseCharge = + featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge) { subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } + else if (customer.HasRecognizedTaxLocation()) + { + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = customer.Address.Country == "US" || + customer.TaxIds.Any() + }; + } try { diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index b450bf5d7f..dd40d7d943 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -1,4 +1,5 @@ using Bit.Commercial.Core.AdminConsole.Providers; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -8,7 +9,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -224,31 +224,115 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Description == string.Empty && + options.Email == organization.BillingEmail && + options.Expand[0] == "tax" && + options.Expand[1] == "tax_ids")).Returns(new Customer + { + Id = "customer_id", + Address = new Address + { + Country = "US" + } + }); + stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription { Id = "subscription_id" }); - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == organization.GatewayCustomerId && - options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && - options.DaysUntilDue == 30 && - options.Metadata["organizationId"] == organization.Id.ToString() && - options.OffSession == true && - options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && - options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && - options.Items.First().Quantity == organization.Seats) - , Arg.Any())) - .Do(x => + await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); + + await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => + options.Customer == organization.GatewayCustomerId && + options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + options.DaysUntilDue == 30 && + options.AutomaticTax.Enabled == true && + options.Metadata["organizationId"] == organization.Id.ToString() && + options.OffSession == true && + options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && + options.Items.First().Quantity == organization.Seats)); + + await sutProvider.GetDependency().Received(1) + .ScaleSeats(provider, organization.PlanType, -organization.Seats ?? 0); + + await organizationRepository.Received(1).ReplaceAsync(Arg.Is( + org => + org.BillingEmail == "a@example.com" && + org.GatewaySubscriptionId == "subscription_id" && + org.Status == OrganizationStatusType.Created)); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(providerOrganization); + + await sutProvider.GetDependency().Received(1) + .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + + await sutProvider.GetDependency().Received(1) + .SendProviderUpdatePaymentMethod( + organization.Id, + organization.Name, + provider.Name, + Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); + } + + [Theory, BitAutoData] + public async Task RemoveOrganizationFromProvider_OrganizationStripeEnabled_ConsolidatedBilling_ReverseCharge_MakesCorrectInvocations( + Provider provider, + ProviderOrganization providerOrganization, + Organization organization, + SutProvider sutProvider) + { + provider.Status = ProviderStatusType.Billable; + + providerOrganization.ProviderId = provider.Id; + + organization.Status = OrganizationStatusType.Managed; + + organization.PlanType = PlanType.TeamsMonthly; + + var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + + sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan); + + sutProvider.GetDependency().HasConfirmedOwnersExceptAsync( + providerOrganization.OrganizationId, + [], + includeProvider: false) + .Returns(true); + + var organizationRepository = sutProvider.GetDependency(); + + organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns([ + "a@example.com", + "b@example.com" + ]); + + var stripeAdapter = sutProvider.GetDependency(); + + stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + options.Description == string.Empty && + options.Email == organization.BillingEmail && + options.Expand[0] == "tax" && + options.Expand[1] == "tax_ids")).Returns(new Customer { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions + Id = "customer_id", + Address = new Address { - Enabled = true - }; + Country = "US" + } }); + stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + { + Id = "subscription_id" + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index 2199bc4bfe..92094d026e 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -262,7 +262,7 @@ public class ProviderBillingServiceTests }; sutProvider.GetDependency().GetCustomerOrThrow(provider, Arg.Is( - options => options.Expand.FirstOrDefault() == "tax_ids")) + options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids"))) .Returns(providerCustomer); sutProvider.GetDependency().BaseServiceUri @@ -312,6 +312,91 @@ public class ProviderBillingServiceTests org => org.GatewayCustomerId == "customer_id")); } + [Theory, BitAutoData] + public async Task CreateCustomer_ForClientOrg_ReverseCharge_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = null; + organization.Name = "Name"; + organization.BusinessName = "BusinessName"; + + var providerCustomer = new Customer + { + Address = new Address + { + Country = "CA", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Unit 4", + City = "Fake Town", + State = "Fake State" + }, + TaxIds = new StripeList + { + Data = + [ + new TaxId { Type = "TYPE", Value = "VALUE" } + ] + } + }; + + sutProvider.GetDependency().GetCustomerOrThrow(provider, Arg.Is( + options => options.Expand.Contains("tax") && options.Expand.Contains("tax_ids"))) + .Returns(providerCustomer); + + sutProvider.GetDependency().BaseServiceUri + .Returns(new Bit.Core.Settings.GlobalSettings.BaseServiceUriSettings(new Bit.Core.Settings.GlobalSettings()) + { + CloudRegion = "US" + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + + sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + options => + options.Address.Country == providerCustomer.Address.Country && + options.Address.PostalCode == providerCustomer.Address.PostalCode && + options.Address.Line1 == providerCustomer.Address.Line1 && + options.Address.Line2 == providerCustomer.Address.Line2 && + options.Address.City == providerCustomer.Address.City && + options.Address.State == providerCustomer.Address.State && + options.Name == organization.DisplayName() && + options.Description == $"{provider.Name} Client Organization" && + options.Email == provider.BillingEmail && + options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && + options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && + options.Metadata["region"] == "US" && + options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && + options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value && + options.TaxExempt == StripeConstants.TaxExempt.Reverse)) + .Returns(new Customer { Id = "customer_id" }); + + await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); + + await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + options => + options.Address.Country == providerCustomer.Address.Country && + options.Address.PostalCode == providerCustomer.Address.PostalCode && + options.Address.Line1 == providerCustomer.Address.Line1 && + options.Address.Line2 == providerCustomer.Address.Line2 && + options.Address.City == providerCustomer.Address.City && + options.Address.State == providerCustomer.Address.State && + options.Name == organization.DisplayName() && + options.Description == $"{provider.Name} Client Organization" && + options.Email == provider.BillingEmail && + options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && + options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && + options.Metadata["region"] == "US" && + options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && + options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value)); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.GatewayCustomerId == "customer_id")); + } + #endregion #region GenerateClientInvoiceReport @@ -1182,6 +1267,62 @@ public class ProviderBillingServiceTests Assert.Equivalent(expected, actual); } + [Theory, BitAutoData] + public async Task SetupCustomer_WithCard_ReverseCharge_Success( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + sutProvider.GetDependency() + .GetStripeTaxCode(Arg.Is( + p => p == taxInfo.BillingAddressCountry), + Arg.Is(p => p == taxInfo.TaxIdNumber)) + .Returns(taxInfo.TaxIdType); + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + var expected = new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.PaymentMethod == tokenizedPaymentSource.Token && + o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber && + o.TaxExempt == StripeConstants.TaxExempt.Reverse)) + .Returns(expected); + + var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + + Assert.Equivalent(expected, actual); + } + [Theory, BitAutoData] public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid( SutProvider sutProvider, @@ -1307,7 +1448,7 @@ public class ProviderBillingServiceTests .Returns(new Customer { Id = "customer_id", - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + Address = new Address { Country = "US" } }); var providerPlans = new List @@ -1359,7 +1500,7 @@ public class ProviderBillingServiceTests var customer = new Customer { Id = "customer_id", - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + Address = new Address { Country = "US" } }; sutProvider.GetDependency() .GetCustomerOrThrow( @@ -1399,19 +1540,6 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == "customer_id") - , Arg.Is(p => p == customer))) - .Do(x => - { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }; - }); - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && @@ -1443,11 +1571,11 @@ public class ProviderBillingServiceTests var customer = new Customer { Id = "customer_id", + Address = new Address { Country = "US" }, InvoiceSettings = new CustomerInvoiceSettings { DefaultPaymentMethodId = "pm_123" - }, - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + } }; sutProvider.GetDependency() @@ -1488,19 +1616,6 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == "customer_id") - , Arg.Is(p => p == customer))) - .Do(x => - { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }; - }); - sutProvider.GetDependency() .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); @@ -1536,9 +1651,9 @@ public class ProviderBillingServiceTests var customer = new Customer { Id = "customer_id", + Address = new Address { Country = "US" }, InvoiceSettings = new CustomerInvoiceSettings(), - Metadata = new Dictionary(), - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + Metadata = new Dictionary() }; sutProvider.GetDependency() @@ -1579,19 +1694,6 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == "customer_id") - , Arg.Is(p => p == customer))) - .Do(x => - { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }; - }); - sutProvider.GetDependency() .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); @@ -1646,12 +1748,15 @@ public class ProviderBillingServiceTests var customer = new Customer { Id = "customer_id", + Address = new Address + { + Country = "US" + }, InvoiceSettings = new CustomerInvoiceSettings(), Metadata = new Dictionary { ["btCustomerId"] = "braintree_customer_id" - }, - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + } }; sutProvider.GetDependency() @@ -1692,22 +1797,92 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency() - .When(x => x.SetCreateOptions( - Arg.Is(options => - options.Customer == "customer_id") - , Arg.Is(p => p == customer))) - .Do(x => + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && + sub.Customer == "customer_id" && + sub.DaysUntilDue == null && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + sub.TrialPeriodDays == 14)).Returns(expected); + + var actual = await sutProvider.Sut.SetupSubscription(provider); + + Assert.Equivalent(expected, actual); + } + + [Theory, BitAutoData] + public async Task SetupSubscription_ReverseCharge_Succeeds( + SutProvider sutProvider, + Provider provider) + { + provider.Type = ProviderType.Msp; + provider.GatewaySubscriptionId = null; + + var customer = new Customer + { + Id = "customer_id", + Address = new Address { Country = "CA" }, + InvoiceSettings = new CustomerInvoiceSettings { - x.Arg().AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }; - }); + DefaultPaymentMethodId = "pm_123" + } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow( + provider, + Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer); + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }, + new() + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + } + }; + + foreach (var plan in providerPlans) + { + sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) + .Returns(StaticStore.GetPlan(plan.PlanType)); + } + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; sutProvider.GetDependency() .IsEnabled(FeatureFlagKeys.PM19956_RequireProviderPaymentMethodDuringSetup).Returns(true); + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && diff --git a/src/Admin/AdminConsole/Views/Shared/_OrganizationFormScripts.cshtml b/src/Admin/AdminConsole/Views/Shared/_OrganizationFormScripts.cshtml index 0ce25c700f..ea4448d100 100644 --- a/src/Admin/AdminConsole/Views/Shared/_OrganizationFormScripts.cshtml +++ b/src/Admin/AdminConsole/Views/Shared/_OrganizationFormScripts.cshtml @@ -69,7 +69,7 @@ document.getElementById('@(nameof(Model.UseGroups))').checked = plan.hasGroups; document.getElementById('@(nameof(Model.UsePolicies))').checked = plan.hasPolicies; document.getElementById('@(nameof(Model.UseSso))').checked = plan.hasSso; - document.getElementById('@(nameof(Model.UseOrganizationDomains))').checked = hasOrganizationDomains; + document.getElementById('@(nameof(Model.UseOrganizationDomains))').checked = plan.hasOrganizationDomains; document.getElementById('@(nameof(Model.UseScim))').checked = plan.hasScim; document.getElementById('@(nameof(Model.UseDirectory))').checked = plan.hasDirectory; document.getElementById('@(nameof(Model.UseEvents))').checked = plan.hasEvents; diff --git a/src/Api/Billing/Controllers/OrganizationBillingController.cs b/src/Api/Billing/Controllers/OrganizationBillingController.cs index 3ebae433d8..1ae1f2e655 100644 --- a/src/Api/Billing/Controllers/OrganizationBillingController.cs +++ b/src/Api/Billing/Controllers/OrganizationBillingController.cs @@ -292,15 +292,17 @@ public class OrganizationBillingController( sale.Organization.PlanType = plan.Type; sale.Organization.Plan = plan.Name; sale.SubscriptionSetup.SkipTrial = true; - await organizationBillingService.Finalize(sale); + + if (organizationSignup.PaymentMethodType == null || string.IsNullOrEmpty(organizationSignup.PaymentToken)) + { + return Error.BadRequest("A payment method is required to restart the subscription."); + } var org = await organizationRepository.GetByIdAsync(organizationId); Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine."); - if (organizationSignup.PaymentMethodType != null) - { - var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken); - var taxInformation = TaxInformation.From(organizationSignup.TaxInfo); - await organizationBillingService.UpdatePaymentMethod(org, paymentSource, taxInformation); - } + var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken); + var taxInformation = TaxInformation.From(organizationSignup.TaxInfo); + await organizationBillingService.UpdatePaymentMethod(org, paymentSource, taxInformation); + await organizationBillingService.Finalize(sale); return TypedResults.Ok(); } diff --git a/src/Api/Billing/Controllers/OrganizationsController.cs b/src/Api/Billing/Controllers/OrganizationsController.cs index 510f6c2835..bd5ab8cef4 100644 --- a/src/Api/Billing/Controllers/OrganizationsController.cs +++ b/src/Api/Billing/Controllers/OrganizationsController.cs @@ -109,28 +109,6 @@ public class OrganizationsController( return license; } - [HttpPost("{id:guid}/payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostPayment(Guid id, [FromBody] PaymentRequestModel model) - { - if (!await currentContext.EditPaymentMethods(id)) - { - throw new NotFoundException(); - } - - await organizationService.ReplacePaymentMethodAsync(id, model.PaymentToken, - model.PaymentMethodType.Value, new TaxInfo - { - BillingAddressLine1 = model.Line1, - BillingAddressLine2 = model.Line2, - BillingAddressState = model.State, - BillingAddressCity = model.City, - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - TaxIdNumber = model.TaxId, - }); - } - [HttpPost("{id:guid}/upgrade")] [SelfHosted(NotSelfHostedOnly = true)] public async Task PostUpgrade(Guid id, [FromBody] OrganizationUpgradeRequestModel model) diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index 4c27098f38..e31d1dceb7 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -1,11 +1,11 @@ using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Services.Contracts; -using Bit.Core.Billing.Tax.Services; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -25,8 +25,7 @@ public class UpcomingInvoiceHandler( IStripeEventService stripeEventService, IStripeEventUtilityService stripeEventUtilityService, IUserRepository userRepository, - IValidateSponsorshipCommand validateSponsorshipCommand, - IAutomaticTaxFactory automaticTaxFactory) + IValidateSponsorshipCommand validateSponsorshipCommand) : IUpcomingInvoiceHandler { public async Task HandleAsync(Event parsedEvent) @@ -46,6 +45,8 @@ public class UpcomingInvoiceHandler( var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); + var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + if (organizationId.HasValue) { var organization = await organizationRepository.GetByIdAsync(organizationId.Value); @@ -55,7 +56,7 @@ public class UpcomingInvoiceHandler( return; } - await TryEnableAutomaticTaxAsync(subscription); + await AlignOrganizationTaxConcernsAsync(organization, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge); var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); @@ -100,7 +101,25 @@ public class UpcomingInvoiceHandler( return; } - await TryEnableAutomaticTaxAsync(subscription); + if (!subscription.AutomaticTax.Enabled && subscription.Customer.HasRecognizedTaxLocation()) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}", + user.Id, + parsedEvent.Id); + } + } if (user.Premium) { @@ -116,7 +135,7 @@ public class UpcomingInvoiceHandler( return; } - await TryEnableAutomaticTaxAsync(subscription); + await AlignProviderTaxConcernsAsync(provider, subscription, parsedEvent.Id, setNonUSBusinessUseToReverseCharge); await SendUpcomingInvoiceEmailsAsync(new List { provider.BillingEmail }, invoice); } @@ -139,50 +158,123 @@ public class UpcomingInvoiceHandler( } } - private async Task TryEnableAutomaticTaxAsync(Subscription subscription) + private async Task AlignOrganizationTaxConcernsAsync( + Organization organization, + Subscription subscription, + string eventId, + bool setNonUSBusinessUseToReverseCharge) { - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) - { - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id)); - var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); - var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); + var nonUSBusinessUse = + organization.PlanType.GetProductTier() != ProductTierType.Families && + subscription.Customer.Address.Country != "US"; - if (updateOptions == null) + bool setAutomaticTaxToEnabled; + + if (setNonUSBusinessUseToReverseCharge) + { + if (nonUSBusinessUse && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) { - return; + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}", + organization.Id, + eventId); + } } - await stripeFacade.UpdateSubscription(subscription.Id, updateOptions); - return; + setAutomaticTaxToEnabled = true; } - - if (subscription.AutomaticTax.Enabled || - !subscription.Customer.HasBillingLocation() || - await IsNonTaxableNonUSBusinessUseSubscription(subscription)) + else { - return; + setAutomaticTaxToEnabled = + subscription.Customer.HasRecognizedTaxLocation() && + (subscription.Customer.Address.Country == "US" || + (nonUSBusinessUse && subscription.Customer.TaxIds.Any())); } - await stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions + if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled) + { + try { - DefaultTaxRates = [], - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}", + organization.Id, + eventId); + } + } + } - return; + private async Task AlignProviderTaxConcernsAsync( + Provider provider, + Subscription subscription, + string eventId, + bool setNonUSBusinessUseToReverseCharge) + { + bool setAutomaticTaxToEnabled; - async Task IsNonTaxableNonUSBusinessUseSubscription(Subscription localSubscription) + if (setNonUSBusinessUseToReverseCharge) { - var familyPriceIds = (await Task.WhenAll( - pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), - pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually))) - .Select(plan => plan.PasswordManager.StripePlanId); + if (subscription.Customer.Address.Country != "US" && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) + { + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}", + provider.Id, + eventId); + } + } - return localSubscription.Customer.Address.Country != "US" && - localSubscription.Metadata.ContainsKey(StripeConstants.MetadataKeys.OrganizationId) && - !localSubscription.Items.Select(item => item.Price.Id).Intersect(familyPriceIds).Any() && - !localSubscription.Customer.TaxIds.Any(); + setAutomaticTaxToEnabled = true; + } + else + { + setAutomaticTaxToEnabled = + subscription.Customer.HasRecognizedTaxLocation() && + (subscription.Customer.Address.Country == "US" || + subscription.Customer.TaxIds.Any()); + } + + if (!subscription.AutomaticTax.Enabled && setAutomaticTaxToEnabled) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}", + provider.Id, + eventId); + } } } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/GetOrganizationUsersClaimedStatusQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/GetOrganizationUsersClaimedStatusQuery.cs index 1dda9483cd..d8c510119a 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/GetOrganizationUsersClaimedStatusQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/GetOrganizationUsersClaimedStatusQuery.cs @@ -24,9 +24,7 @@ public class GetOrganizationUsersClaimedStatusQuery : IGetOrganizationUsersClaim // Users can only be claimed by an Organization that is enabled and can have organization domains var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId); - // TODO: Replace "UseSso" with a new organization ability like "UseOrganizationDomains" (PM-11622). - // Verified domains were tied to SSO, so we currently check the "UseSso" organization ability. - if (organizationAbility is { Enabled: true, UseSso: true }) + if (organizationAbility is { Enabled: true, UseOrganizationDomains: true }) { // Get all organization users with claimed domains by the organization var organizationUsersWithClaimedDomain = await _organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organizationId); diff --git a/src/Core/AdminConsole/Services/IOrganizationService.cs b/src/Core/AdminConsole/Services/IOrganizationService.cs index 1e53be734e..8baad23f65 100644 --- a/src/Core/AdminConsole/Services/IOrganizationService.cs +++ b/src/Core/AdminConsole/Services/IOrganizationService.cs @@ -11,8 +11,6 @@ namespace Bit.Core.Services; public interface IOrganizationService { - Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType, - TaxInfo taxInfo); Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null); Task ReinstateSubscriptionAsync(Guid organizationId); Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb); diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 1ced923b45..4e9d9bdb8a 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -144,27 +144,6 @@ public class OrganizationService : IOrganizationService _sendOrganizationInvitesCommand = sendOrganizationInvitesCommand; } - public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, - PaymentMethodType paymentMethodType, TaxInfo taxInfo) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - await _paymentService.SaveTaxInfoAsync(organization, taxInfo); - var updated = await _paymentService.UpdatePaymentMethodAsync( - organization, - paymentMethodType, - paymentToken, - taxInfo); - if (updated) - { - await ReplaceAndUpdateCacheAsync(organization); - } - } - public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null) { var organization = await GetOrgById(organizationId); diff --git a/src/Core/Auth/Enums/EmergencyAccessStatusType.cs b/src/Core/Auth/Enums/EmergencyAccessStatusType.cs index 7faaa11752..d817d6a950 100644 --- a/src/Core/Auth/Enums/EmergencyAccessStatusType.cs +++ b/src/Core/Auth/Enums/EmergencyAccessStatusType.cs @@ -2,9 +2,24 @@ public enum EmergencyAccessStatusType : byte { + /// + /// The user has been invited to be an emergency contact. + /// Invited = 0, + /// + /// The invited user, "grantee", has accepted the request to be an emergency contact. + /// Accepted = 1, + /// + /// The inviting user, "grantor", has approved the grantee's acceptance. + /// Confirmed = 2, + /// + /// The grantee has initiated the recovery process. + /// RecoveryInitiated = 3, + /// + /// The grantee has excercised their emergency access. + /// RecoveryApproved = 4, } diff --git a/src/Core/Auth/Services/IEmergencyAccessService.cs b/src/Core/Auth/Services/IEmergencyAccessService.cs index 2c94632510..6dd17151e6 100644 --- a/src/Core/Auth/Services/IEmergencyAccessService.cs +++ b/src/Core/Auth/Services/IEmergencyAccessService.cs @@ -3,6 +3,7 @@ using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.Services; using Bit.Core.Vault.Models.Data; @@ -20,6 +21,15 @@ public interface IEmergencyAccessService Task InitiateAsync(Guid id, User initiatingUser); Task ApproveAsync(Guid id, User approvingUser); Task RejectAsync(Guid id, User rejectingUser); + /// + /// This request is made by the Grantee user to fetch the policies for the Grantor User. + /// The Grantor User has to be the owner of the organization. + /// If the Grantor user has OrganizationUserType.Owner then the policies for the _Grantor_ user + /// are returned. + /// + /// EmergencyAccess.Id being acted on + /// User making the request, this is the Grantee + /// null if the GrantorUser is not an organization owner; A list of policies otherwise. Task> GetPoliciesAsync(Guid id, User requestingUser); Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser); Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key); diff --git a/src/Core/Auth/Services/Implementations/EmergencyAccessService.cs b/src/Core/Auth/Services/Implementations/EmergencyAccessService.cs index dda16e29fe..2418830ea7 100644 --- a/src/Core/Auth/Services/Implementations/EmergencyAccessService.cs +++ b/src/Core/Auth/Services/Implementations/EmergencyAccessService.cs @@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; -using Bit.Core.Auth.Models; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; @@ -16,7 +15,6 @@ using Bit.Core.Tokens; using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Repositories; using Bit.Core.Vault.Services; -using Microsoft.AspNetCore.Identity; namespace Bit.Core.Auth.Services; @@ -31,8 +29,6 @@ public class EmergencyAccessService : IEmergencyAccessService private readonly IMailService _mailService; private readonly IUserService _userService; private readonly GlobalSettings _globalSettings; - private readonly IPasswordHasher _passwordHasher; - private readonly IOrganizationService _organizationService; private readonly IDataProtectorTokenFactory _dataProtectorTokenizer; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; @@ -45,9 +41,7 @@ public class EmergencyAccessService : IEmergencyAccessService ICipherService cipherService, IMailService mailService, IUserService userService, - IPasswordHasher passwordHasher, GlobalSettings globalSettings, - IOrganizationService organizationService, IDataProtectorTokenFactory dataProtectorTokenizer, IRemoveOrganizationUserCommand removeOrganizationUserCommand) { @@ -59,9 +53,7 @@ public class EmergencyAccessService : IEmergencyAccessService _cipherService = cipherService; _mailService = mailService; _userService = userService; - _passwordHasher = passwordHasher; _globalSettings = globalSettings; - _organizationService = organizationService; _dataProtectorTokenizer = dataProtectorTokenizer; _removeOrganizationUserCommand = removeOrganizationUserCommand; } @@ -126,7 +118,12 @@ public class EmergencyAccessService : IEmergencyAccessService throw new BadRequestException("Emergency Access not valid."); } - if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email)) + if (!_dataProtectorTokenizer.TryUnprotect(token, out var data)) + { + throw new BadRequestException("Invalid token."); + } + + if (!data.IsValid(emergencyAccessId, user.Email)) { throw new BadRequestException("Invalid token."); } @@ -140,6 +137,8 @@ public class EmergencyAccessService : IEmergencyAccessService throw new BadRequestException("Invitation already accepted."); } + // TODO PM-21687 + // Might not be reachable since the Tokenable.IsValid() does an email comparison if (string.IsNullOrWhiteSpace(emergencyAccess.Email) || !emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) { @@ -163,6 +162,8 @@ public class EmergencyAccessService : IEmergencyAccessService public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId) { var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); + // TODO PM-19438/PM-21687 + // Not sure why the GrantorId and the GranteeId are supposed to be the same? if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId)) { throw new BadRequestException("Emergency Access not valid."); @@ -171,9 +172,9 @@ public class EmergencyAccessService : IEmergencyAccessService await _emergencyAccessRepository.DeleteAsync(emergencyAccess); } - public async Task ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId) + public async Task ConfirmUserAsync(Guid emergencyAccessId, string key, Guid confirmingUserId) { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId); + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted || emergencyAccess.GrantorId != confirmingUserId) { @@ -224,7 +225,6 @@ public class EmergencyAccessService : IEmergencyAccessService public async Task InitiateAsync(Guid id, User initiatingUser) { var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id || emergencyAccess.Status != EmergencyAccessStatusType.Confirmed) { @@ -285,6 +285,9 @@ public class EmergencyAccessService : IEmergencyAccessService public async Task> GetPoliciesAsync(Guid id, User requestingUser) { + // TODO PM-21687 + // Should we look up policies here or just verify the EmergencyAccess is correct + // and handle policy logic else where? Should this be a query/Command? var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) @@ -295,7 +298,9 @@ public class EmergencyAccessService : IEmergencyAccessService var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); - var isOrganizationOwner = grantorOrganizations.Any(organization => organization.Type == OrganizationUserType.Owner); + var isOrganizationOwner = grantorOrganizations + .Any(organization => organization.Type == OrganizationUserType.Owner); + var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null; return policies; @@ -311,7 +316,8 @@ public class EmergencyAccessService : IEmergencyAccessService } var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - + // TODO PM-21687 + // Redundant check of the EmergencyAccessType -> checked in IsValidRequest() ln 308 if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) { throw new BadRequestException("You cannot takeover an account that is using Key Connector."); @@ -336,7 +342,9 @@ public class EmergencyAccessService : IEmergencyAccessService grantor.LastPasswordChangeDate = grantor.RevisionDate; grantor.Key = key; // Disable TwoFactor providers since they will otherwise block logins - grantor.SetTwoFactorProviders(new Dictionary()); + grantor.SetTwoFactorProviders([]); + // Disable New Device Verification since it will otherwise block logins + grantor.VerifyDevices = false; await _userRepository.ReplaceAsync(grantor); // Remove grantor from all organizations unless Owner @@ -421,12 +429,22 @@ public class EmergencyAccessService : IEmergencyAccessService await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token); } - private string NameOrEmail(User user) + private static string NameOrEmail(User user) { return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name; } - private bool IsValidRequest(EmergencyAccess availableAccess, User requestingUser, EmergencyAccessType requestedAccessType) + + /* + * Checks if EmergencyAccess Object is null + * Checks the requesting user is the same as the granteeUser (So we are checking for proper grantee action) + * Status _must_ equal RecoveryApproved (This means the grantor has invited, the grantee has accepted, and the grantor has approved so the shared key exists but hasn't been exercised yet) + * request type must equal the type of access requested (View or Takeover) + */ + private static bool IsValidRequest( + EmergencyAccess availableAccess, + User requestingUser, + EmergencyAccessType requestedAccessType) { return availableAccess != null && availableAccess.GranteeId == requestingUser.Id && diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index c3e3ec6c30..28f4dea4b2 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -2,10 +2,6 @@ public static class StripeConstants { - public static class Prices - { - public const string StoragePlanPersonal = "personal-storage-gb-annually"; - } public static class AutomaticTaxStatus { public const string Failed = "failed"; @@ -69,6 +65,11 @@ public static class StripeConstants public const string USBankAccount = "us_bank_account"; } + public static class Prices + { + public const string StoragePlanPersonal = "personal-storage-gb-annually"; + } + public static class ProrationBehavior { public const string AlwaysInvoice = "always_invoice"; @@ -88,6 +89,13 @@ public static class StripeConstants public const string Paused = "paused"; } + public static class TaxExempt + { + public const string Exempt = "exempt"; + public const string None = "none"; + public const string Reverse = "reverse"; + } + public static class ValidateTaxLocationTiming { public const string Deferred = "deferred"; diff --git a/src/Core/Billing/Extensions/CustomerExtensions.cs b/src/Core/Billing/Extensions/CustomerExtensions.cs index 3e0c1ea0fb..aa22331f7c 100644 --- a/src/Core/Billing/Extensions/CustomerExtensions.cs +++ b/src/Core/Billing/Extensions/CustomerExtensions.cs @@ -15,12 +15,7 @@ public static class CustomerExtensions } }; - /// - /// Determines if a Stripe customer supports automatic tax - /// - /// - /// - public static bool HasTaxLocationVerified(this Customer customer) => + public static bool HasRecognizedTaxLocation(this Customer customer) => customer?.Tax?.AutomaticTax != StripeConstants.AutomaticTaxStatus.UnrecognizedLocation; public static decimal GetBillingBalance(this Customer customer) diff --git a/src/Core/Billing/Extensions/SubscriptionUpdateOptionsExtensions.cs b/src/Core/Billing/Extensions/SubscriptionUpdateOptionsExtensions.cs index d70af78fa8..22a715733b 100644 --- a/src/Core/Billing/Extensions/SubscriptionUpdateOptionsExtensions.cs +++ b/src/Core/Billing/Extensions/SubscriptionUpdateOptionsExtensions.cs @@ -22,7 +22,7 @@ public static class SubscriptionUpdateOptionsExtensions } // We might only need to check the automatic tax status. - if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) + if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country)) { return false; } diff --git a/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs b/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs index 88df5638c9..d00b5b46a4 100644 --- a/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs +++ b/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs @@ -22,7 +22,7 @@ public static class UpcomingInvoiceOptionsExtensions } // We might only need to check the automatic tax status. - if (!customer.HasTaxLocationVerified() && string.IsNullOrWhiteSpace(customer.Address?.Country)) + if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country)) { return false; } diff --git a/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs b/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs index 4d93c0119a..204022380d 100644 --- a/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs +++ b/src/Core/Billing/Migration/Services/Implementations/OrganizationMigrator.cs @@ -309,6 +309,7 @@ public class OrganizationMigrator( organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; organization.UsePolicies = plan.HasPolicies; organization.UseSso = plan.HasSso; + organization.UseOrganizationDomains = plan.HasOrganizationDomains; organization.UseGroups = plan.HasGroups; organization.UseEvents = plan.HasEvents; organization.UseDirectory = plan.HasDirectory; diff --git a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs b/src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs index 72db7897b4..b584647a26 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs +++ b/src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs @@ -26,6 +26,7 @@ public record Enterprise2019Plan : Plan Has2fa = true; HasApi = true; HasSso = true; + HasOrganizationDomains = true; HasKeyConnector = true; HasScim = true; HasResetPassword = true; diff --git a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs b/src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs index 42b984e7e5..a1a6113cbc 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs +++ b/src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs @@ -26,6 +26,7 @@ public record Enterprise2020Plan : Plan Has2fa = true; HasApi = true; HasSso = true; + HasOrganizationDomains = true; HasKeyConnector = true; HasScim = true; HasResetPassword = true; diff --git a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs b/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs index 2d498a7654..8aeca521d1 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs +++ b/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs @@ -26,6 +26,7 @@ public record EnterprisePlan : Plan Has2fa = true; HasApi = true; HasSso = true; + HasOrganizationDomains = true; HasKeyConnector = true; HasScim = true; HasResetPassword = true; diff --git a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs b/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs index 8cd8335425..dce1719a49 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs +++ b/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs @@ -26,6 +26,7 @@ public record Enterprise2023Plan : Plan Has2fa = true; HasApi = true; HasSso = true; + HasOrganizationDomains = true; HasKeyConnector = true; HasScim = true; HasResetPassword = true; diff --git a/src/Core/Billing/Pricing/PlanAdapter.cs b/src/Core/Billing/Pricing/PlanAdapter.cs index c38eb0501d..f719fd1e87 100644 --- a/src/Core/Billing/Pricing/PlanAdapter.cs +++ b/src/Core/Billing/Pricing/PlanAdapter.cs @@ -26,6 +26,7 @@ public record PlanAdapter : Plan Has2fa = HasFeature("2fa"); HasApi = HasFeature("api"); HasSso = HasFeature("sso"); + HasOrganizationDomains = HasFeature("organizationDomains"); HasKeyConnector = HasFeature("keyConnector"); HasScim = HasFeature("scim"); HasResetPassword = HasFeature("resetPassword"); diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index 20f6105c2a..95df34dfd4 100644 --- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -1,11 +1,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; @@ -35,16 +35,15 @@ public class OrganizationBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - ITaxService taxService, - IAutomaticTaxFactory automaticTaxFactory) : IOrganizationBillingService + ITaxService taxService) : IOrganizationBillingService { public async Task Finalize(OrganizationSale sale) { var (organization, customerSetup, subscriptionSetup) = sale; var customer = string.IsNullOrEmpty(organization.GatewayCustomerId) && customerSetup != null - ? await CreateCustomerAsync(organization, customerSetup) - : await subscriberService.GetCustomerOrThrow(organization, new CustomerGetOptions { Expand = ["tax", "tax_ids"] }); + ? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType) + : await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup); var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup); @@ -121,7 +120,8 @@ public class OrganizationBillingService( subscription.CurrentPeriodEnd); } - public async Task UpdatePaymentMethod( + public async Task + UpdatePaymentMethod( Organization organization, TokenizedPaymentSource tokenizedPaymentSource, TaxInformation taxInformation) @@ -151,8 +151,11 @@ public class OrganizationBillingService( private async Task CreateCustomerAsync( Organization organization, - CustomerSetup customerSetup) + CustomerSetup customerSetup, + PlanType? updatedPlanType = null) { + var planType = updatedPlanType ?? organization.PlanType; + var displayName = organization.DisplayName(); var customerCreateOptions = new CustomerCreateOptions @@ -212,13 +215,24 @@ public class OrganizationBillingService( City = customerSetup.TaxInformation.City, PostalCode = customerSetup.TaxInformation.PostalCode, State = customerSetup.TaxInformation.State, - Country = customerSetup.TaxInformation.Country, + Country = customerSetup.TaxInformation.Country }; + customerCreateOptions.Tax = new CustomerTaxOptions { ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately }; + var setNonUSBusinessUseToReverseCharge = + featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge && + planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families && + customerSetup.TaxInformation.Country != "US") + { + customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; + } + if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId)) { var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country, @@ -399,21 +413,68 @@ public class OrganizationBillingService( TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays }; - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + var setNonUSBusinessUseToReverseCharge = + featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge) { - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriptionSetup.PlanType); - var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); - automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - else + else if (customer.HasRecognizedTaxLocation()) { - subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions(); - subscriptionCreateOptions.AutomaticTax.Enabled = customer.HasBillingLocation(); + subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = + subscriptionSetup.PlanType.GetProductTier() == ProductTierType.Families || + customer.Address.Country == "US" || + customer.TaxIds.Any() + }; } return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); } + private async Task GetCustomerWhileEnsuringCorrectTaxExemptionAsync( + Organization organization, + SubscriptionSetup subscriptionSetup) + { + var customer = await subscriberService.GetCustomerOrThrow(organization, + new CustomerGetOptions { Expand = ["tax", "tax_ids"] }); + + var setNonUSBusinessUseToReverseCharge = featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (!setNonUSBusinessUseToReverseCharge || subscriptionSetup.PlanType.GetProductTier() is + not (ProductTierType.Teams or + ProductTierType.TeamsStarter or + ProductTierType.Enterprise)) + { + return customer; + } + + List expansions = ["tax", "tax_ids"]; + + customer = customer switch + { + { Address.Country: not "US", TaxExempt: not StripeConstants.TaxExempt.Reverse } => await + stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions + { + Expand = expansions, + TaxExempt = StripeConstants.TaxExempt.Reverse + }), + { Address.Country: "US", TaxExempt: StripeConstants.TaxExempt.Reverse } => await + stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions + { + Expand = expansions, + TaxExempt = StripeConstants.TaxExempt.None + }), + _ => customer + }; + + return customer; + } + private async Task IsEligibleForSelfHostAsync( Organization organization) { diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 1b845e93f1..7496157aaa 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -3,8 +3,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Billing.Tax.Services; -using Bit.Core.Billing.Tax.Services.Implementations; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -12,7 +10,6 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Braintree; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Stripe; using Customer = Stripe.Customer; @@ -24,20 +21,18 @@ using static Utilities; public class PremiumUserBillingService( IBraintreeGateway braintreeGateway, - IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - IUserRepository userRepository, - [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy automaticTaxStrategy) : IPremiumUserBillingService + IUserRepository userRepository) : IPremiumUserBillingService { public async Task Credit(User user, decimal amount) { var customer = await subscriberService.GetCustomer(user); - // Negative credit represents a balance and all Stripe denomination is in cents. + // Negative credit represents a balance, and all Stripe denomination is in cents. var credit = (long)(amount * -100); if (customer == null) @@ -184,7 +179,7 @@ public class PremiumUserBillingService( City = customerSetup.TaxInformation.City, PostalCode = customerSetup.TaxInformation.PostalCode, State = customerSetup.TaxInformation.State, - Country = customerSetup.TaxInformation.Country, + Country = customerSetup.TaxInformation.Country }, Description = user.Name, Email = user.Email, @@ -324,6 +319,10 @@ public class PremiumUserBillingService( var subscriptionCreateOptions = new SubscriptionCreateOptions { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }, CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, @@ -337,18 +336,6 @@ public class PremiumUserBillingService( OffSession = true }; - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) - { - automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer); - } - else - { - subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported, - }; - } - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); if (usingPayPal) @@ -380,7 +367,7 @@ public class PremiumUserBillingService( City = taxInformation.City, PostalCode = taxInformation.PostalCode, State = taxInformation.State, - Country = taxInformation.Country, + Country = taxInformation.Country }, Expand = ["tax"], Tax = new CustomerTaxOptions diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 10247cdf92..75a1bf76ec 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -1,7 +1,10 @@ -using Bit.Core.Billing.Caches; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Services; using Bit.Core.Entities; @@ -28,8 +31,7 @@ public class SubscriberService( ILogger logger, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ITaxService taxService, - IAutomaticTaxFactory automaticTaxFactory) : ISubscriberService + ITaxService taxService) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -128,7 +130,7 @@ public class SubscriberService( [subscriber.BraintreeCloudRegionField()] = globalSettings.BaseServiceUri.CloudRegion }, Email = subscriber.BillingEmailAddress(), - PaymentMethodNonce = paymentMethodNonce, + PaymentMethodNonce = paymentMethodNonce }); if (customerResult.IsSuccess()) @@ -482,7 +484,7 @@ public class SubscriberService( var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First(); - // Find the customer's existing setup intents that should be cancelled. + // Find the customer's existing setup intents that should be canceled. var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer) .Where(si => si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action"); @@ -519,7 +521,7 @@ public class SubscriberService( await stripeAdapter.PaymentMethodAttachAsync(token, new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); - // Find the customer's existing setup intents that should be cancelled. + // Find the customer's existing setup intents that should be canceled. var existingSetupIntentsForCustomer = (await getExistingSetupIntentsForCustomer) .Where(si => si.Status is "requires_payment_method" or "requires_confirmation" or "requires_action"); @@ -637,7 +639,8 @@ public class SubscriberService( logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", taxInformation.Country, taxInformation.TaxId); - throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError"); + + throw new BadRequestException("billingTaxIdTypeInferenceError"); } } @@ -654,53 +657,84 @@ public class SubscriberService( logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", taxInformation.TaxId, taxInformation.Country); - throw new Exceptions.BadRequestException("billingInvalidTaxIdError"); + + throw new BadRequestException("billingInvalidTaxIdError"); + default: logger.LogError(e, "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", taxInformation.TaxId, taxInformation.Country, customer.Id); - throw new Exceptions.BadRequestException("billingTaxIdCreationError"); + + throw new BadRequestException("billingTaxIdCreationError"); } } } - if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + var subscription = + customer.Subscriptions.First(subscription => subscription.Id == subscriber.GatewaySubscriptionId); + + var isBusinessUseSubscriber = subscriber switch { - if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + Organization organization => organization.PlanType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families, + Provider => true, + _ => false + }; + + var setNonUSBusinessUseToReverseCharge = + featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge && isBusinessUseSubscriber) + { + switch (customer) { - var subscriptionGetOptions = new SubscriptionGetOptions + case { - Expand = ["customer.tax", "customer.tax_ids"] - }; - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); - var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters); - var automaticTaxOptions = automaticTaxStrategy.GetUpdateOptions(subscription); - if (automaticTaxOptions?.AutomaticTax?.Enabled != null) + Address.Country: not "US", + TaxExempt: not StripeConstants.TaxExempt.Reverse + }: + await stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); + break; + case { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, automaticTaxOptions); - } + Address.Country: "US", + TaxExempt: StripeConstants.TaxExempt.Reverse + }: + await stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.None }); + break; } - } - else - { - if (SubscriberIsEligibleForAutomaticTax(subscriber, customer)) + + if (!subscription.AutomaticTax.Enabled) { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, + await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } }); } + } + else + { + var automaticTaxShouldBeEnabled = subscriber switch + { + User => true, + Organization organization => organization.PlanType.GetProductTier() == ProductTierType.Families || + customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false), + Provider => customer.Address.Country == "US" || (customer.TaxIds?.Any() ?? false), + _ => false + }; - return; - - bool SubscriberIsEligibleForAutomaticTax(ISubscriber localSubscriber, Customer localCustomer) - => !string.IsNullOrEmpty(localSubscriber.GatewaySubscriptionId) && - (localCustomer.Subscriptions?.Any(sub => sub.Id == localSubscriber.GatewaySubscriptionId && !sub.AutomaticTax.Enabled) ?? false) && - localCustomer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; + if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled) + { + await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } } } diff --git a/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs b/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs index 310aced130..6affc57354 100644 --- a/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs +++ b/src/Core/Billing/Tax/Services/Implementations/BusinessUseAutomaticTaxStrategy.cs @@ -76,7 +76,7 @@ public class BusinessUseAutomaticTaxStrategy(IFeatureService featureService) : I private bool ShouldBeEnabled(Customer customer) { - if (!customer.HasTaxLocationVerified()) + if (!customer.HasRecognizedTaxLocation()) { return false; } diff --git a/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs b/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs index e89fc6a3b3..615222259e 100644 --- a/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs +++ b/src/Core/Billing/Tax/Services/Implementations/PersonalUseAutomaticTaxStrategy.cs @@ -59,6 +59,6 @@ public class PersonalUseAutomaticTaxStrategy(IFeatureService featureService) : I private static bool ShouldBeEnabled(Customer customer) { - return customer.HasTaxLocationVerified(); + return customer.HasRecognizedTaxLocation(); } } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 707001ddcc..694521c14e 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -143,13 +143,13 @@ public static class FeatureFlagKeys public const string UsePricingService = "use-pricing-service"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method"; - public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements"; public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion"; public const string PM199566_UpdateMSPToChargeAutomatically = "pm-199566-update-msp-to-charge-automatically"; public const string PM19956_RequireProviderPaymentMethodDuringSetup = "pm-19956-require-provider-payment-method-during-setup"; public const string UseOrganizationWarningsService = "use-organization-warnings-service"; public const string PM20322_AllowTrialLength0 = "pm-20322-allow-trial-length-0"; + public const string PM21092_SetNonUSBusinessUseToReverseCharge = "pm-21092-set-non-us-business-use-to-reverse-charge"; /* Data Insights and Reporting Team */ public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application"; diff --git a/src/Core/Models/Business/OrganizationLicense.cs b/src/Core/Models/Business/OrganizationLicense.cs index a201da3847..e8c04b1277 100644 --- a/src/Core/Models/Business/OrganizationLicense.cs +++ b/src/Core/Models/Business/OrganizationLicense.cs @@ -84,6 +84,7 @@ public class OrganizationLicense : ILicense SmSeats = org.SmSeats; SmServiceAccounts = org.SmServiceAccounts; UseRiskInsights = org.UseRiskInsights; + UseOrganizationDomains = org.UseOrganizationDomains; // Deprecated. Left for backwards compatibility with old license versions. LimitCollectionCreationDeletion = org.LimitCollectionCreation || org.LimitCollectionDeletion; @@ -195,10 +196,10 @@ public class OrganizationLicense : ILicense /// Intentionally set one version behind to allow self hosted users some time to update before /// getting out of date license errors /// - public const int CurrentLicenseFileVersion = 14; + public const int CurrentLicenseFileVersion = 15; private bool ValidLicenseVersion { - get => Version is >= 1 and <= 15; + get => Version is >= 1 and <= 16; } public byte[] GetDataBytes(bool forHash = false) @@ -244,6 +245,8 @@ public class OrganizationLicense : ILicense (Version >= 14 || !p.Name.Equals(nameof(LimitCollectionCreationDeletion))) && // AllowAdminAccessToAllCollectionItems was added in Version 15 (Version >= 15 || !p.Name.Equals(nameof(AllowAdminAccessToAllCollectionItems))) && + // UseOrganizationDomains was added in Version 16 + (Version >= 16 || !p.Name.Equals(nameof(UseOrganizationDomains))) && ( !forHash || ( @@ -252,7 +255,10 @@ public class OrganizationLicense : ILicense !p.Name.Equals(nameof(Refresh)) ) ) && - !p.Name.Equals(nameof(UseRiskInsights))) + // any new fields added need to be added here so that they're ignored + !p.Name.Equals(nameof(UseRiskInsights)) && + !p.Name.Equals(nameof(UseAdminSponsoredFamilies)) && + !p.Name.Equals(nameof(UseOrganizationDomains))) .OrderBy(p => p.Name) .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") .Aggregate((c, n) => $"{c}|{n}"); @@ -583,6 +589,11 @@ public class OrganizationLicense : ILicense * validation. */ + if (valid && Version >= 16) + { + valid = organization.UseOrganizationDomains; + } + return valid; } diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index 3fdb829cf4..af96b88ee6 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Models; using Bit.Core.Billing.Tax.Requests; using Bit.Core.Billing.Tax.Responses; using Bit.Core.Entities; -using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; @@ -30,8 +29,6 @@ public interface IPaymentService Task AdjustServiceAccountsAsync(Organization organization, Plan plan, int additionalServiceAccounts); Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false); Task ReinstateSubscriptionAsync(ISubscriber subscriber); - Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, - string paymentToken, TaxInfo taxInfo = null); Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount); Task GetBillingAsync(ISubscriber subscriber); Task GetBillingHistoryAsync(ISubscriber subscriber); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 65c0525535..34be6d59c5 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1,13 +1,13 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Contracts; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Billing.Tax.Requests; using Bit.Core.Billing.Tax.Responses; using Bit.Core.Billing.Tax.Services; @@ -38,7 +38,6 @@ public class StripePaymentService : IPaymentService private readonly IGlobalSettings _globalSettings; private readonly IFeatureService _featureService; private readonly ITaxService _taxService; - private readonly ISubscriberService _subscriberService; private readonly IPricingClient _pricingClient; private readonly IAutomaticTaxFactory _automaticTaxFactory; private readonly IAutomaticTaxStrategy _personalUseTaxStrategy; @@ -51,7 +50,6 @@ public class StripePaymentService : IPaymentService IGlobalSettings globalSettings, IFeatureService featureService, ITaxService taxService, - ISubscriberService subscriberService, IPricingClient pricingClient, IAutomaticTaxFactory automaticTaxFactory, [FromKeyedServices(AutomaticTaxFactory.PersonalUse)] IAutomaticTaxStrategy personalUseTaxStrategy) @@ -63,7 +61,6 @@ public class StripePaymentService : IPaymentService _globalSettings = globalSettings; _featureService = featureService; _taxService = taxService; - _subscriberService = subscriberService; _pricingClient = pricingClient; _automaticTaxFactory = automaticTaxFactory; _personalUseTaxStrategy = personalUseTaxStrategy; @@ -136,15 +133,68 @@ public class StripePaymentService : IPaymentService if (subscriptionUpdate is CompleteSubscriptionUpdate) { - if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) + var setNonUSBusinessUseToReverseCharge = + _featureService.IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge); + + if (setNonUSBusinessUseToReverseCharge) { - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, updatedItemOptions.Select(x => x.Plan ?? x.Price)); - var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); - automaticTaxStrategy.SetUpdateOptions(subUpdateOptions, sub); + if (sub.Customer is + { + Address.Country: not "US", + TaxExempt: not StripeConstants.TaxExempt.Reverse + }) + { + await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId, + new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); + } + + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - else + else if (sub.Customer.HasRecognizedTaxLocation()) { - subUpdateOptions.EnableAutomaticTax(sub.Customer, sub); + switch (subscriber) + { + case User: + { + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + break; + } + case Organization: + { + if (sub.Customer.Address.Country == "US") + { + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; + } + else + { + var familyPriceIds = (await Task.WhenAll( + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019), + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually))) + .Select(plan => plan.PasswordManager.StripePlanId); + + var updateIsForPersonalUse = updatedItemOptions + .Select(option => option.Price) + .Intersect(familyPriceIds) + .Any(); + + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = updateIsForPersonalUse || sub.Customer.TaxIds.Any() + }; + } + + break; + } + case Provider: + { + subUpdateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = sub.Customer.Address.Country == "US" || + sub.Customer.TaxIds.Any() + }; + break; + } + } } } @@ -202,7 +252,7 @@ public class StripePaymentService : IPaymentService } else if (!invoice.Paid) { - // Pay invoice with no charge to customer this completes the invoice immediately without waiting the scheduled 1h + // Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); paymentIntentClientSecret = null; } @@ -585,309 +635,6 @@ public class StripePaymentService : IPaymentService } } - public async Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, - string paymentToken, TaxInfo taxInfo = null) - { - if (subscriber == null) - { - throw new ArgumentNullException(nameof(subscriber)); - } - - if (subscriber.Gateway.HasValue && subscriber.Gateway.Value != GatewayType.Stripe) - { - throw new GatewayException("Switching from one payment type to another is not supported. " + - "Contact us for assistance."); - } - - var createdCustomer = false; - Braintree.Customer braintreeCustomer = null; - string stipeCustomerSourceToken = null; - string stipeCustomerPaymentMethodId = null; - var stripeCustomerMetadata = new Dictionary - { - { "region", _globalSettings.BaseServiceUri.CloudRegion } - }; - var stripePaymentMethod = paymentMethodType is PaymentMethodType.Card or PaymentMethodType.BankAccount; - - Customer customer = null; - - if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var options = new CustomerGetOptions { Expand = ["sources", "tax", "subscriptions"] }; - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, options); - if (customer.Metadata?.Any() ?? false) - { - stripeCustomerMetadata = customer.Metadata; - } - } - - var hadBtCustomer = stripeCustomerMetadata.ContainsKey("btCustomerId"); - if (stripePaymentMethod) - { - if (paymentToken.StartsWith("pm_")) - { - stipeCustomerPaymentMethodId = paymentToken; - } - else - { - stipeCustomerSourceToken = paymentToken; - } - } - else if (paymentMethodType == PaymentMethodType.PayPal) - { - if (hadBtCustomer) - { - var pmResult = await _btGateway.PaymentMethod.CreateAsync(new Braintree.PaymentMethodRequest - { - CustomerId = stripeCustomerMetadata["btCustomerId"], - PaymentMethodNonce = paymentToken - }); - - if (pmResult.IsSuccess()) - { - var customerResult = await _btGateway.Customer.UpdateAsync( - stripeCustomerMetadata["btCustomerId"], new Braintree.CustomerRequest - { - DefaultPaymentMethodToken = pmResult.Target.Token - }); - - if (customerResult.IsSuccess() && customerResult.Target.PaymentMethods.Length > 0) - { - braintreeCustomer = customerResult.Target; - } - else - { - await _btGateway.PaymentMethod.DeleteAsync(pmResult.Target.Token); - hadBtCustomer = false; - } - } - else - { - hadBtCustomer = false; - } - } - - if (!hadBtCustomer) - { - var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest - { - PaymentMethodNonce = paymentToken, - Email = subscriber.BillingEmailAddress(), - Id = subscriber.BraintreeCustomerIdPrefix() + subscriber.Id.ToString("N").ToLower() + - Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false), - CustomFields = new Dictionary - { - [subscriber.BraintreeIdField()] = subscriber.Id.ToString(), - [subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion - } - }); - - if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) - { - throw new GatewayException("Failed to create PayPal customer record."); - } - - braintreeCustomer = customerResult.Target; - } - } - else - { - throw new GatewayException("Payment method is not supported at this time."); - } - - if (stripeCustomerMetadata.ContainsKey("btCustomerId")) - { - if (braintreeCustomer?.Id != stripeCustomerMetadata["btCustomerId"]) - { - stripeCustomerMetadata["btCustomerId_old"] = stripeCustomerMetadata["btCustomerId"]; - } - - stripeCustomerMetadata["btCustomerId"] = braintreeCustomer?.Id; - } - else if (!string.IsNullOrWhiteSpace(braintreeCustomer?.Id)) - { - stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); - } - - try - { - if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)) - { - taxInfo.TaxIdType = taxInfo.TaxIdType ?? - _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber); - } - - if (customer == null) - { - customer = await _stripeAdapter.CustomerCreateAsync(new CustomerCreateOptions - { - Description = subscriber.BillingName(), - Email = subscriber.BillingEmailAddress(), - Metadata = stripeCustomerMetadata, - Source = stipeCustomerSourceToken, - PaymentMethod = stipeCustomerPaymentMethodId, - InvoiceSettings = new CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = stipeCustomerPaymentMethodId, - CustomFields = - [ - new CustomerInvoiceSettingsCustomFieldOptions() - { - Name = subscriber.SubscriberType(), - Value = subscriber.GetFormattedInvoiceName() - } - - ] - }, - Address = taxInfo == null ? null : new AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState - }, - TaxIdData = string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber) - ? [] - : [ - new CustomerTaxIdDataOptions - { - Type = taxInfo.TaxIdType, - Value = taxInfo.TaxIdNumber - } - ], - Expand = ["sources", "tax", "subscriptions"], - }); - - subscriber.Gateway = GatewayType.Stripe; - subscriber.GatewayCustomerId = customer.Id; - createdCustomer = true; - } - - if (!createdCustomer) - { - string defaultSourceId = null; - string defaultPaymentMethodId = null; - if (stripePaymentMethod) - { - if (!string.IsNullOrWhiteSpace(stipeCustomerSourceToken) && paymentToken.StartsWith("btok_")) - { - var bankAccount = await _stripeAdapter.BankAccountCreateAsync(customer.Id, new BankAccountCreateOptions - { - Source = paymentToken - }); - defaultSourceId = bankAccount.Id; - } - else if (!string.IsNullOrWhiteSpace(stipeCustomerPaymentMethodId)) - { - await _stripeAdapter.PaymentMethodAttachAsync(stipeCustomerPaymentMethodId, - new PaymentMethodAttachOptions { Customer = customer.Id }); - defaultPaymentMethodId = stipeCustomerPaymentMethodId; - } - } - - if (customer.Sources != null) - { - foreach (var source in customer.Sources.Where(s => s.Id != defaultSourceId)) - { - if (source is BankAccount) - { - await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); - } - else if (source is Card) - { - await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id); - } - } - } - - var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging(new PaymentMethodListOptions - { - Customer = customer.Id, - Type = "card" - }); - foreach (var cardMethod in cardPaymentMethods.Where(m => m.Id != defaultPaymentMethodId)) - { - await _stripeAdapter.PaymentMethodDetachAsync(cardMethod.Id, new PaymentMethodDetachOptions()); - } - - await _subscriberService.UpdateTaxInformation(subscriber, TaxInformation.From(taxInfo)); - - customer = await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions - { - Metadata = stripeCustomerMetadata, - DefaultSource = defaultSourceId, - InvoiceSettings = new CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = defaultPaymentMethodId, - CustomFields = - [ - new CustomerInvoiceSettingsCustomFieldOptions() - { - Name = subscriber.SubscriberType(), - Value = subscriber.GetFormattedInvoiceName() - } - ] - }, - Expand = ["tax", "subscriptions"] - }); - } - - if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements)) - { - if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) - { - var subscriptionGetOptions = new SubscriptionGetOptions - { - Expand = ["customer.tax", "customer.tax_ids"] - }; - var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); - - var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscriber, subscription.Items.Select(x => x.Price.Id)); - var automaticTaxStrategy = await _automaticTaxFactory.CreateAsync(automaticTaxParameters); - var subscriptionUpdateOptions = automaticTaxStrategy.GetUpdateOptions(subscription); - - if (subscriptionUpdateOptions != null) - { - _ = await _stripeAdapter.SubscriptionUpdateAsync( - subscriber.GatewaySubscriptionId, - subscriptionUpdateOptions); - } - } - } - else - { - if (!string.IsNullOrEmpty(subscriber.GatewaySubscriptionId) && - customer.Subscriptions.Any(sub => - sub.Id == subscriber.GatewaySubscriptionId && - !sub.AutomaticTax.Enabled) && - customer.HasTaxLocationVerified()) - { - var subscriptionUpdateOptions = new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, - DefaultTaxRates = [] - }; - - _ = await _stripeAdapter.SubscriptionUpdateAsync( - subscriber.GatewaySubscriptionId, - subscriptionUpdateOptions); - } - } - } - catch - { - if (braintreeCustomer != null && !hadBtCustomer) - { - await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); - } - throw; - } - - return createdCustomer; - } - public async Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount) { Customer customer = null; @@ -1018,7 +765,7 @@ public class StripePaymentService : IPaymentService var address = customer.Address; var taxId = customer.TaxIds?.FirstOrDefault(); - // Line1 is required, so if missing we're using the subscriber name + // Line1 is required, so if missing we're using the subscriber name, // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 if (address != null && string.IsNullOrWhiteSpace(address.Line1)) { diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index 151ff38aa5..76520b4085 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -1341,9 +1341,7 @@ public class UserService : UserManager, IUserService, IDisposable var organizationsWithVerifiedUserEmailDomain = await _organizationRepository.GetByVerifiedUserEmailDomainAsync(userId); // Organizations must be enabled and able to have verified domains. - // TODO: Replace "UseSso" with a new organization ability like "UseOrganizationDomains" (PM-11622). - // Verified domains were tied to SSO, so we currently check the "UseSso" organization ability. - return organizationsWithVerifiedUserEmailDomain.Where(organization => organization is { Enabled: true, UseSso: true }); + return organizationsWithVerifiedUserEmailDomain.Where(organization => organization is { Enabled: true, UseOrganizationDomains: true }); } /// diff --git a/src/Identity/Models/Response/Accounts/RegisterFinishResponseModel.cs b/src/Identity/Models/Response/Accounts/RegisterFinishResponseModel.cs index d7c7b94366..564150ab30 100644 --- a/src/Identity/Models/Response/Accounts/RegisterFinishResponseModel.cs +++ b/src/Identity/Models/Response/Accounts/RegisterFinishResponseModel.cs @@ -6,5 +6,12 @@ public class RegisterFinishResponseModel : ResponseModel { public RegisterFinishResponseModel() : base("registerFinish") - { } + { + // We are setting this to an empty string so that old mobile clients don't break, as they reqiure a non-null value. + // This will be cleaned up in https://bitwarden.atlassian.net/browse/PM-21720. + CaptchaBypassToken = string.Empty; + } + + public string CaptchaBypassToken { get; set; } + } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/GetOrganizationUsersClaimedStatusQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/GetOrganizationUsersClaimedStatusQueryTests.cs index fd6d827791..85dc643022 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/GetOrganizationUsersClaimedStatusQueryTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/GetOrganizationUsersClaimedStatusQueryTests.cs @@ -25,13 +25,13 @@ public class GetOrganizationUsersClaimedStatusQueryTests } [Theory, BitAutoData] - public async Task GetUsersOrganizationManagementStatusAsync_WithUseSsoEnabled_Success( + public async Task GetUsersOrganizationManagementStatusAsync_WithUseOrganizationDomainsEnabled_Success( Organization organization, ICollection usersWithClaimedDomain, SutProvider sutProvider) { organization.Enabled = true; - organization.UseSso = true; + organization.UseOrganizationDomains = true; var userIdWithoutClaimedDomain = Guid.NewGuid(); var userIdsToCheck = usersWithClaimedDomain.Select(u => u.Id).Concat(new List { userIdWithoutClaimedDomain }).ToList(); @@ -51,13 +51,13 @@ public class GetOrganizationUsersClaimedStatusQueryTests } [Theory, BitAutoData] - public async Task GetUsersOrganizationManagementStatusAsync_WithUseSsoDisabled_ReturnsAllFalse( + public async Task GetUsersOrganizationManagementStatusAsync_WithUseOrganizationDomainsDisabled_ReturnsAllFalse( Organization organization, ICollection usersWithClaimedDomain, SutProvider sutProvider) { organization.Enabled = true; - organization.UseSso = false; + organization.UseOrganizationDomains = false; var userIdWithoutClaimedDomain = Guid.NewGuid(); var userIdsToCheck = usersWithClaimedDomain.Select(u => u.Id).Concat(new List { userIdWithoutClaimedDomain }).ToList(); diff --git a/test/Core.Test/Auth/Services/EmergencyAccessServiceTests.cs b/test/Core.Test/Auth/Services/EmergencyAccessServiceTests.cs index 6c2352ca00..006515aafd 100644 --- a/test/Core.Test/Auth/Services/EmergencyAccessServiceTests.cs +++ b/test/Core.Test/Auth/Services/EmergencyAccessServiceTests.cs @@ -1,11 +1,17 @@ -using Bit.Core.Auth.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; +using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Services; using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Tokens; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -17,27 +23,21 @@ namespace Bit.Core.Test.Auth.Services; public class EmergencyAccessServiceTests { [Theory, BitAutoData] - public async Task SaveAsync_PremiumCannotUpdate( - SutProvider sutProvider, User savingUser) + public async Task InviteAsync_UserWithOutPremium_ThrowsBadRequest( + SutProvider sutProvider, User invitingUser, string email, int waitTime) { - savingUser.Premium = false; - var emergencyAccess = new EmergencyAccess - { - Type = EmergencyAccessType.Takeover, - GrantorId = savingUser.Id, - }; - - sutProvider.GetDependency().GetUserByIdAsync(savingUser.Id).Returns(savingUser); + sutProvider.GetDependency().CanAccessPremium(invitingUser).Returns(false); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); + () => sutProvider.Sut.InviteAsync(invitingUser, email, EmergencyAccessType.Takeover, waitTime)); Assert.Contains("Not a premium user.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs().CreateAsync(default); } [Theory, BitAutoData] - public async Task InviteAsync_UserWithKeyConnectorCannotUseTakeover( + public async Task InviteAsync_UserWithKeyConnector_ThrowsBadRequest( SutProvider sutProvider, User invitingUser, string email, int waitTime) { invitingUser.UsesKeyConnector = true; @@ -47,11 +47,461 @@ public class EmergencyAccessServiceTests () => sutProvider.Sut.InviteAsync(invitingUser, email, EmergencyAccessType.Takeover, waitTime)); Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs().CreateAsync(default); + } + + [Theory] + [BitAutoData(EmergencyAccessType.Takeover)] + [BitAutoData(EmergencyAccessType.View)] + public async Task InviteAsync_ReturnsEmergencyAccessObject( + EmergencyAccessType accessType, SutProvider sutProvider, User invitingUser, string email, int waitTime) + { + sutProvider.GetDependency().CanAccessPremium(invitingUser).Returns(true); + + var result = await sutProvider.Sut.InviteAsync(invitingUser, email, accessType, waitTime); + + Assert.NotNull(result); + Assert.Equal(accessType, result.Type); + Assert.Equal(invitingUser.Id, result.GrantorId); + Assert.Equal(email, result.Email); + Assert.Equal(EmergencyAccessStatusType.Invited, result.Status); + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(Arg.Any()); + sutProvider.GetDependency>() + .Received(1) + .Protect(Arg.Any()); + await sutProvider.GetDependency() + .Received(1) + .SendEmergencyAccessInviteEmailAsync(Arg.Any(), Arg.Any(), Arg.Any()); } [Theory, BitAutoData] - public async Task ConfirmUserAsync_UserWithKeyConnectorCannotUseTakeover( + public async Task GetAsync_EmergencyAccessNull_ThrowsBadRequest( + SutProvider sutProvider, User user) + { + EmergencyAccessDetails emergencyAccess = null; + sutProvider.GetDependency() + .GetDetailsByIdGrantorIdAsync(Arg.Any(), Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.GetAsync(new Guid(), user.Id)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task ResendInviteAsync_EmergencyAccessNull_ThrowsBadRequest( + SutProvider sutProvider, + User invitingUser, + Guid emergencyAccessId) + { + EmergencyAccess emergencyAccess = null; + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ResendInviteAsync(invitingUser, emergencyAccessId)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendEmergencyAccessInviteEmailAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task ResendInviteAsync_InvitingUserIdNotGrantorUserId_ThrowsBadRequest( + SutProvider sutProvider, + User invitingUser, + Guid emergencyAccessId) + { + var emergencyAccess = new EmergencyAccess + { + Status = EmergencyAccessStatusType.Invited, + GrantorId = Guid.NewGuid(), + Type = EmergencyAccessType.Takeover, + }; ; + + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ResendInviteAsync(invitingUser, emergencyAccessId)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendEmergencyAccessInviteEmailAsync(default, default, default); + } + + [Theory] + [BitAutoData(EmergencyAccessStatusType.Accepted)] + [BitAutoData(EmergencyAccessStatusType.Confirmed)] + [BitAutoData(EmergencyAccessStatusType.RecoveryInitiated)] + [BitAutoData(EmergencyAccessStatusType.RecoveryApproved)] + public async Task ResendInviteAsync_EmergencyAccessStatusInvalid_ThrowsBadRequest( + EmergencyAccessStatusType statusType, + SutProvider sutProvider, + User invitingUser, + Guid emergencyAccessId) + { + var emergencyAccess = new EmergencyAccess + { + Status = statusType, + GrantorId = invitingUser.Id, + Type = EmergencyAccessType.Takeover, + }; + + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ResendInviteAsync(invitingUser, emergencyAccessId)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendEmergencyAccessInviteEmailAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task ResendInviteAsync_SendsInviteAsync( + SutProvider sutProvider, + User invitingUser, + Guid emergencyAccessId) + { + var emergencyAccess = new EmergencyAccess + { + Status = EmergencyAccessStatusType.Invited, + GrantorId = invitingUser.Id, + Type = EmergencyAccessType.Takeover, + }; ; + + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + + await sutProvider.Sut.ResendInviteAsync(invitingUser, emergencyAccessId); + sutProvider.GetDependency>() + .Received(1) + .Protect(Arg.Any()); + await sutProvider.GetDependency() + .Received(1) + .SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUser.Name, Arg.Any()); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_EmergencyAccessNull_ThrowsBadRequest( + SutProvider sutProvider, User acceptingUser, string token) + { + EmergencyAccess emergencyAccess = null; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(new Guid(), acceptingUser, token, sutProvider.GetDependency())); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_CannotUnprotectToken_ThrowsBadRequest( + SutProvider sutProvider, + User acceptingUser, + EmergencyAccess emergencyAccess, + string token) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency>() + .TryUnprotect(token, out Arg.Any()) + .Returns(false); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(emergencyAccess.Id, acceptingUser, token, sutProvider.GetDependency())); + + Assert.Contains("Invalid token.", exception.Message); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_TokenDataInvalid_ThrowsBadRequest( + SutProvider sutProvider, + User acceptingUser, + EmergencyAccess emergencyAccess, + EmergencyAccess wrongEmergencyAccess, + string token) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency>() + .TryUnprotect(token, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new EmergencyAccessInviteTokenable(wrongEmergencyAccess, 1); + return true; + }); + + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(emergencyAccess.Id, acceptingUser, token, sutProvider.GetDependency())); + + Assert.Contains("Invalid token.", exception.Message); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_AcceptedStatus_ThrowsBadRequest( + SutProvider sutProvider, + User acceptingUser, + EmergencyAccess emergencyAccess, + string token) + { + emergencyAccess.Status = EmergencyAccessStatusType.Accepted; + emergencyAccess.Email = acceptingUser.Email; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency>() + .TryUnprotect(token, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new EmergencyAccessInviteTokenable(emergencyAccess, 1); + return true; + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(emergencyAccess.Id, acceptingUser, token, sutProvider.GetDependency())); + + Assert.Contains("Invitation already accepted. You will receive an email when the grantor confirms you as an emergency access contact.", exception.Message); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_NotInvitedStatus_ThrowsBadRequest( + SutProvider sutProvider, + User acceptingUser, + EmergencyAccess emergencyAccess, + string token) + { + emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; + emergencyAccess.Email = acceptingUser.Email; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency>() + .TryUnprotect(token, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new EmergencyAccessInviteTokenable(emergencyAccess, 1); + return true; + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(emergencyAccess.Id, acceptingUser, token, sutProvider.GetDependency())); + + Assert.Contains("Invitation already accepted.", exception.Message); + } + + [Theory(Skip = "Code not reachable, Tokenable checks email match in IsValid()"), BitAutoData] + public async Task AcceptUserAsync_EmergencyAccessEmailDoesNotMatch_ThrowsBadRequest( + SutProvider sutProvider, + User acceptingUser, + EmergencyAccess emergencyAccess, + string token) + { + emergencyAccess.Status = EmergencyAccessStatusType.Invited; + emergencyAccess.Email = acceptingUser.Email; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency>() + .TryUnprotect(token, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new EmergencyAccessInviteTokenable(emergencyAccess, 1); + return true; + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(emergencyAccess.Id, acceptingUser, token, sutProvider.GetDependency())); + + Assert.Contains("User email does not match invite.", exception.Message); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_ReplaceEmergencyAccess_SendsEmail_Success( + SutProvider sutProvider, + User acceptingUser, + User invitingUser, + EmergencyAccess emergencyAccess, + string token) + { + emergencyAccess.Status = EmergencyAccessStatusType.Invited; + emergencyAccess.Email = acceptingUser.Email; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetUserByIdAsync(Arg.Any()) + .Returns(invitingUser); + + sutProvider.GetDependency>() + .TryUnprotect(token, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new EmergencyAccessInviteTokenable(emergencyAccess, 1); + return true; + }); + + await sutProvider.Sut.AcceptUserAsync(emergencyAccess.Id, acceptingUser, token, sutProvider.GetDependency()); + + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.Accepted)); + + await sutProvider.GetDependency() + .Received(1) + .SendEmergencyAccessAcceptedEmailAsync(acceptingUser.Email, invitingUser.Email); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_EmergencyAccessNull_ThrowsBadRequest( + SutProvider sutProvider, + User invitingUser, + EmergencyAccess emergencyAccess) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((EmergencyAccess)null); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(emergencyAccess.Id, invitingUser.Id)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_EmergencyAccessGrantorIdNotEqual_ThrowsBadRequest( + SutProvider sutProvider, + User invitingUser, + EmergencyAccess emergencyAccess) + { + emergencyAccess.GrantorId = Guid.NewGuid(); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(emergencyAccess.Id, invitingUser.Id)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_EmergencyAccessGranteeIdNotEqual_ThrowsBadRequest( + SutProvider sutProvider, + User invitingUser, + EmergencyAccess emergencyAccess) + { + emergencyAccess.GranteeId = Guid.NewGuid(); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(emergencyAccess.Id, invitingUser.Id)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_EmergencyAccessIsDeleted_Success( + SutProvider sutProvider, + User user, + EmergencyAccess emergencyAccess) + { + emergencyAccess.GranteeId = user.Id; + emergencyAccess.GrantorId = user.Id; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + await sutProvider.Sut.DeleteAsync(emergencyAccess.Id, user.Id); + + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(emergencyAccess); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_EmergencyAccessNull_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + string key, + User grantorUser) + { + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((EmergencyAccess)null); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(emergencyAccess.Id, key, grantorUser.Id)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_EmergencyAccessStatusIsNotAccepted_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + string key, + User grantorUser) + { + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.Id) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(emergencyAccess.Id, key, grantorUser.Id)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_EmergencyAccessGrantorIdNotEqualToConfirmingUserId_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + string key, + User grantorUser) + { + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(emergencyAccess.Id, key, grantorUser.Id)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_UserWithKeyConnectorCannotUseTakeover_ThrowsBadRequest( SutProvider sutProvider, User confirmingUser, string key) { confirmingUser.UsesKeyConnector = true; @@ -62,8 +512,13 @@ public class EmergencyAccessServiceTests Type = EmergencyAccessType.Takeover, }; - sutProvider.GetDependency().GetByIdAsync(confirmingUser.Id).Returns(confirmingUser); - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + sutProvider.GetDependency() + .GetByIdAsync(confirmingUser.Id) + .Returns(confirmingUser); + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.ConfirmUserAsync(new Guid(), key, confirmingUser.Id)); @@ -73,29 +528,210 @@ public class EmergencyAccessServiceTests } [Theory, BitAutoData] - public async Task SaveAsync_UserWithKeyConnectorCannotUseTakeover( + public async Task ConfirmUserAsync_ConfirmsAndReplacesEmergencyAccess_Success( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + string key, + User grantorUser, + User granteeUser) + { + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.Accepted; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetByIdAsync(grantorUser.Id) + .Returns(grantorUser); + + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.GranteeId.Value) + .Returns(granteeUser); + + await sutProvider.Sut.ConfirmUserAsync(emergencyAccess.Id, key, grantorUser.Id); + + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.Confirmed)); + + await sutProvider.GetDependency() + .Received(1) + .SendEmergencyAccessConfirmedEmailAsync(grantorUser.Name, granteeUser.Email); + } + + [Theory, BitAutoData] + public async Task SaveAsync_PremiumCannotUpdate_ThrowsBadRequest( SutProvider sutProvider, User savingUser) { - savingUser.UsesKeyConnector = true; var emergencyAccess = new EmergencyAccess { Type = EmergencyAccessType.Takeover, GrantorId = savingUser.Id, }; - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(savingUser.Id).Returns(savingUser); - userService.CanAccessPremium(savingUser).Returns(true); + sutProvider.GetDependency() + .CanAccessPremium(savingUser) + .Returns(false); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); + Assert.Contains("Not a premium user.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_EmergencyAccessGrantorIdNotEqualToSavingUserId_ThrowsBadRequest( + SutProvider sutProvider, User savingUser) + { + savingUser.Premium = true; + var emergencyAccess = new EmergencyAccess + { + Type = EmergencyAccessType.Takeover, + GrantorId = new Guid(), + }; + + sutProvider.GetDependency() + .GetUserByIdAsync(savingUser.Id) + .Returns(savingUser); + sutProvider.GetDependency() + .CanAccessPremium(savingUser) + .Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_GrantorUserWithKeyConnectorCannotTakeover_ThrowsBadRequest( + SutProvider sutProvider, User grantorUser) + { + grantorUser.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess + { + Type = EmergencyAccessType.Takeover, + GrantorId = grantorUser.Id, + }; + + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(grantorUser.Id).Returns(grantorUser); + userService.CanAccessPremium(grantorUser).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(emergencyAccess, grantorUser)); + Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); } [Theory, BitAutoData] - public async Task InitiateAsync_UserWithKeyConnectorCannotUseTakeover( + public async Task SaveAsync_GrantorUserWithKeyConnectorCanView_SavesEmergencyAccess( + SutProvider sutProvider, User grantorUser) + { + grantorUser.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess + { + Type = EmergencyAccessType.View, + GrantorId = grantorUser.Id, + }; + + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(grantorUser.Id).Returns(grantorUser); + userService.CanAccessPremium(grantorUser).Returns(true); + + await sutProvider.Sut.SaveAsync(emergencyAccess, grantorUser); + + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(emergencyAccess); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ValidRequest_SavesEmergencyAccess( + SutProvider sutProvider, User grantorUser) + { + grantorUser.UsesKeyConnector = false; + var emergencyAccess = new EmergencyAccess + { + Type = EmergencyAccessType.Takeover, + GrantorId = grantorUser.Id, + }; + + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(grantorUser.Id).Returns(grantorUser); + userService.CanAccessPremium(grantorUser).Returns(true); + + await sutProvider.Sut.SaveAsync(emergencyAccess, grantorUser); + + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(emergencyAccess); + } + + [Theory, BitAutoData] + public async Task InitiateAsync_EmergencyAccessNull_ThrowBadRequest( + SutProvider sutProvider, User initiatingUser) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((EmergencyAccess)null); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .ReplaceAsync(default); + } + + [Theory, BitAutoData] + public async Task InitiateAsync_EmergencyAccessGranteeIdNotEqual_ThrowBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User initiatingUser) + { + emergencyAccess.GranteeId = new Guid(); + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.Id) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .ReplaceAsync(default); + } + + [Theory, BitAutoData] + public async Task InitiateAsync_EmergencyAccessStatusIsNotConfirmed_ThrowBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User initiatingUser) + { + emergencyAccess.GranteeId = initiatingUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.Invited; + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.Id) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .ReplaceAsync(default); + } + + [Theory, BitAutoData] + public async Task InitiateAsync_UserWithKeyConnectorCannotUseTakeover_ThrowsBadRequest( SutProvider sutProvider, User initiatingUser, User grantor) { grantor.UsesKeyConnector = true; @@ -107,40 +743,711 @@ public class EmergencyAccessServiceTests Type = EmergencyAccessType.Takeover, }; - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + sutProvider.GetDependency() + .GetByIdAsync(grantor.Id) + .Returns(grantor); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser)); Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .ReplaceAsync(default); } [Theory, BitAutoData] - public async Task TakeoverAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User requestingUser, User grantor) + public async Task InitiateAsync_UserWithKeyConnectorCanView_Success( + SutProvider sutProvider, User initiatingUser, User grantor) + { + grantor.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess + { + Status = EmergencyAccessStatusType.Confirmed, + GranteeId = initiatingUser.Id, + GrantorId = grantor.Id, + Type = EmergencyAccessType.View, + }; + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + sutProvider.GetDependency() + .GetByIdAsync(grantor.Id) + .Returns(grantor); + + await sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser); + + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated)); + } + + [Theory, BitAutoData] + public async Task InitiateAsync_RequestIsCorrect_Success( + SutProvider sutProvider, User initiatingUser, User grantor) + { + var emergencyAccess = new EmergencyAccess + { + Status = EmergencyAccessStatusType.Confirmed, + GranteeId = initiatingUser.Id, + GrantorId = grantor.Id, + Type = EmergencyAccessType.Takeover, + }; + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + sutProvider.GetDependency() + .GetByIdAsync(grantor.Id) + .Returns(grantor); + + await sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser); + + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated)); + } + + [Theory, BitAutoData] + public async Task ApproveAsync_EmergencyAccessNull_ThrowsBadrequest( + SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((EmergencyAccess)null); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ApproveAsync(new Guid(), null)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task ApproveAsync_EmergencyAccessGrantorIdNotEquatToApproving_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User grantorUser) + { + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ApproveAsync(emergencyAccess.Id, grantorUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory] + [BitAutoData(EmergencyAccessStatusType.Invited)] + [BitAutoData(EmergencyAccessStatusType.Accepted)] + [BitAutoData(EmergencyAccessStatusType.Confirmed)] + [BitAutoData(EmergencyAccessStatusType.RecoveryApproved)] + public async Task ApproveAsync_EmergencyAccessStatusNotRecoveryInitiated_ThrowsBadRequest( + EmergencyAccessStatusType statusType, + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User grantorUser) + { + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.Status = statusType; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ApproveAsync(emergencyAccess.Id, grantorUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task ApproveAsync_Success( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User grantorUser, + User granteeUser) + { + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(granteeUser); + + await sutProvider.Sut.ApproveAsync(emergencyAccess.Id, grantorUser); + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.RecoveryApproved)); + } + + [Theory, BitAutoData] + public async Task RejectAsync_EmergencyAccessIdNull_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User GrantorUser) + { + emergencyAccess.GrantorId = GrantorUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.Accepted; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((EmergencyAccess)null); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RejectAsync(emergencyAccess.Id, GrantorUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task RejectAsync_EmergencyAccessGrantorIdNotEqualToRequestUser_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User GrantorUser) + { + emergencyAccess.Status = EmergencyAccessStatusType.Accepted; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RejectAsync(emergencyAccess.Id, GrantorUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory] + [BitAutoData(EmergencyAccessStatusType.Invited)] + [BitAutoData(EmergencyAccessStatusType.Accepted)] + [BitAutoData(EmergencyAccessStatusType.Confirmed)] + public async Task RejectAsync_EmergencyAccessStatusNotValid_ThrowsBadRequest( + EmergencyAccessStatusType statusType, + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User GrantorUser) + { + emergencyAccess.GrantorId = GrantorUser.Id; + emergencyAccess.Status = statusType; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RejectAsync(emergencyAccess.Id, GrantorUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory] + [BitAutoData(EmergencyAccessStatusType.RecoveryInitiated)] + [BitAutoData(EmergencyAccessStatusType.RecoveryApproved)] + public async Task RejectAsync_Success( + EmergencyAccessStatusType statusType, + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User GrantorUser, + User GranteeUser) + { + emergencyAccess.GrantorId = GrantorUser.Id; + emergencyAccess.Status = statusType; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(GranteeUser); + + await sutProvider.Sut.RejectAsync(emergencyAccess.Id, GrantorUser); + + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.Confirmed)); + } + + [Theory, BitAutoData] + public async Task GetPoliciesAsync_RequestNotValidEmergencyAccessNull_ThrowsBadRequest( + SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((EmergencyAccess)null); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.GetPoliciesAsync(default, default)); + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory] + [BitAutoData(EmergencyAccessStatusType.Invited)] + [BitAutoData(EmergencyAccessStatusType.Accepted)] + [BitAutoData(EmergencyAccessStatusType.Confirmed)] + [BitAutoData(EmergencyAccessStatusType.RecoveryInitiated)] + public async Task GetPoliciesAsync_RequestNotValidStatusType_ThrowsBadRequest( + EmergencyAccessStatusType statusType, + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Status = statusType; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.GetPoliciesAsync(emergencyAccess.Id, granteeUser)); + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task GetPoliciesAsync_RequestNotValidType_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.View; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.GetPoliciesAsync(emergencyAccess.Id, granteeUser)); + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory] + [BitAutoData(OrganizationUserType.Admin)] + [BitAutoData(OrganizationUserType.User)] + [BitAutoData(OrganizationUserType.Custom)] + public async Task GetPoliciesAsync_OrganizationUserTypeNotOwner_ReturnsNull( + OrganizationUserType userType, + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser, + User grantorUser, + OrganizationUser grantorOrganizationUser) + { + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.GrantorId) + .Returns(grantorUser); + + grantorOrganizationUser.UserId = grantorUser.Id; + grantorOrganizationUser.Type = userType; + sutProvider.GetDependency() + .GetManyByUserAsync(grantorUser.Id) + .Returns([grantorOrganizationUser]); + + var result = await sutProvider.Sut.GetPoliciesAsync(emergencyAccess.Id, granteeUser); + Assert.Null(result); + } + + [Theory, BitAutoData] + public async Task GetPoliciesAsync_OrganizationUserEmpty_ReturnsNull( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser, + User grantorUser) + { + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.GrantorId) + .Returns(grantorUser); + + sutProvider.GetDependency() + .GetManyByUserAsync(grantorUser.Id) + .Returns([]); + + + var result = await sutProvider.Sut.GetPoliciesAsync(emergencyAccess.Id, granteeUser); + Assert.Null(result); + } + + [Theory, BitAutoData] + public async Task GetPoliciesAsync_ReturnsNotNull( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser, + User grantorUser, + OrganizationUser grantorOrganizationUser) + { + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.GrantorId) + .Returns(grantorUser); + + grantorOrganizationUser.UserId = grantorUser.Id; + grantorOrganizationUser.Type = OrganizationUserType.Owner; + sutProvider.GetDependency() + .GetManyByUserAsync(grantorUser.Id) + .Returns([grantorOrganizationUser]); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(grantorUser.Id) + .Returns([]); + + var result = await sutProvider.Sut.GetPoliciesAsync(emergencyAccess.Id, granteeUser); + Assert.NotNull(result); + } + + [Theory, BitAutoData] + public async Task TakeoverAsync_RequestNotValid_EmergencyAccessIsNull_ThrowsBadRequest( + SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((EmergencyAccess)null); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.TakeoverAsync(default, default)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task TakeoverAsync_RequestNotValid_GranteeNotEqualToRequestingUser_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.TakeoverAsync(new Guid(), granteeUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory] + [BitAutoData(EmergencyAccessStatusType.Invited)] + [BitAutoData(EmergencyAccessStatusType.Accepted)] + [BitAutoData(EmergencyAccessStatusType.Confirmed)] + [BitAutoData(EmergencyAccessStatusType.RecoveryInitiated)] + public async Task TakeoverAsync_RequestNotValid_StatusType_ThrowsBadRequest( + EmergencyAccessStatusType statusType, + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Status = statusType; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.TakeoverAsync(new Guid(), granteeUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task TakeoverAsync_RequestNotValid_TypeIsView_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.View; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.TakeoverAsync(new Guid(), granteeUser)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task TakeoverAsync_UserWithKeyConnectorCannotUseTakeover_ThrowsBadRequest( + SutProvider sutProvider, + User granteeUser, + User grantor) { grantor.UsesKeyConnector = true; var emergencyAccess = new EmergencyAccess { GrantorId = grantor.Id, - GranteeId = requestingUser.Id, + GranteeId = granteeUser.Id, Status = EmergencyAccessStatusType.RecoveryApproved, Type = EmergencyAccessType.Takeover, }; - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetByIdAsync(grantor.Id) + .Returns(grantor); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.TakeoverAsync(new Guid(), requestingUser)); + () => sutProvider.Sut.TakeoverAsync(new Guid(), granteeUser)); Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); } [Theory, BitAutoData] - public async Task PasswordAsync_Disables_2FA_Providers_On_The_Grantor( + public async Task TakeoverAsync_Success_ReturnsEmergencyAccessAndGrantorUser( + SutProvider sutProvider, + User granteeUser, + User grantor) + { + grantor.UsesKeyConnector = false; + var emergencyAccess = new EmergencyAccess + { + GrantorId = grantor.Id, + GranteeId = granteeUser.Id, + Status = EmergencyAccessStatusType.RecoveryApproved, + Type = EmergencyAccessType.Takeover, + }; + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetByIdAsync(grantor.Id) + .Returns(grantor); + + var result = await sutProvider.Sut.TakeoverAsync(new Guid(), granteeUser); + + Assert.Equal(result.Item1, emergencyAccess); + Assert.Equal(result.Item2, grantor); + } + + [Theory, BitAutoData] + public async Task PasswordAsync_RequestNotValid_EmergencyAccessIsNull_ThrowsBadRequest( + SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((EmergencyAccess)null); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PasswordAsync(default, default, default, default)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task PasswordAsync_RequestNotValid_GranteeNotEqualToRequestingUser_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PasswordAsync(emergencyAccess.Id, granteeUser, default, default)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory] + [BitAutoData(EmergencyAccessStatusType.Invited)] + [BitAutoData(EmergencyAccessStatusType.Accepted)] + [BitAutoData(EmergencyAccessStatusType.Confirmed)] + [BitAutoData(EmergencyAccessStatusType.RecoveryInitiated)] + public async Task PasswordAsync_RequestNotValid_StatusType_ThrowsBadRequest( + EmergencyAccessStatusType statusType, + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Status = statusType; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PasswordAsync(emergencyAccess.Id, granteeUser, default, default)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task PasswordAsync_RequestNotValid_TypeIsView_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.View; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PasswordAsync(emergencyAccess.Id, granteeUser, default, default)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + } + + [Theory, BitAutoData] + public async Task PasswordAsync_NonOrgUser_Success( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser, + User grantorUser, + string key, + string passwordHash) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.GrantorId) + .Returns(grantorUser); + + await sutProvider.Sut.PasswordAsync(emergencyAccess.Id, granteeUser, passwordHash, key); + + await sutProvider.GetDependency() + .Received(1) + .UpdatePasswordHash(grantorUser, passwordHash); + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(Arg.Is(u => u.VerifyDevices == false && u.Key == key)); + } + + [Theory] + [BitAutoData(OrganizationUserType.User)] + [BitAutoData(OrganizationUserType.Admin)] + [BitAutoData(OrganizationUserType.Custom)] + public async Task PasswordAsync_OrgUser_NotOrganizationOwner_RemovedFromOrganization_Success( + OrganizationUserType userType, + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser, + User grantorUser, + OrganizationUser organizationUser, + string key, + string passwordHash) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.GrantorId) + .Returns(grantorUser); + + organizationUser.UserId = grantorUser.Id; + organizationUser.Type = userType; + sutProvider.GetDependency() + .GetManyByUserAsync(grantorUser.Id) + .Returns([organizationUser]); + + await sutProvider.Sut.PasswordAsync(emergencyAccess.Id, granteeUser, passwordHash, key); + + await sutProvider.GetDependency() + .Received(1) + .UpdatePasswordHash(grantorUser, passwordHash); + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(Arg.Is(u => u.VerifyDevices == false && u.Key == key)); + await sutProvider.GetDependency() + .Received(1) + .RemoveUserAsync(organizationUser.OrganizationId, organizationUser.UserId.Value); + } + + [Theory, BitAutoData] + public async Task PasswordAsync_OrgUser_IsOrganizationOwner_NotRemovedFromOrganization_Success( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser, + User grantorUser, + OrganizationUser organizationUser, + string key, + string passwordHash) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.GrantorId = grantorUser.Id; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + sutProvider.GetDependency() + .GetByIdAsync(emergencyAccess.GrantorId) + .Returns(grantorUser); + + organizationUser.UserId = grantorUser.Id; + organizationUser.Type = OrganizationUserType.Owner; + sutProvider.GetDependency() + .GetManyByUserAsync(grantorUser.Id) + .Returns([organizationUser]); + + await sutProvider.Sut.PasswordAsync(emergencyAccess.Id, granteeUser, passwordHash, key); + + await sutProvider.GetDependency() + .Received(1) + .UpdatePasswordHash(grantorUser, passwordHash); + await sutProvider.GetDependency() + .Received(1) + .ReplaceAsync(Arg.Is(u => u.VerifyDevices == false && u.Key == key)); + await sutProvider.GetDependency() + .Received(0) + .RemoveUserAsync(organizationUser.OrganizationId, organizationUser.UserId.Value); + } + + [Theory, BitAutoData] + public async Task PasswordAsync_Disables_NewDeviceVerification_And_TwoFactorProviders_On_The_Grantor( SutProvider sutProvider, User requestingUser, User grantor) { grantor.UsesKeyConnector = true; @@ -160,12 +1467,49 @@ public class EmergencyAccessServiceTests Type = EmergencyAccessType.Takeover, }; - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + sutProvider.GetDependency() + .GetByIdAsync(grantor.Id) + .Returns(grantor); await sutProvider.Sut.PasswordAsync(Guid.NewGuid(), requestingUser, "blablahash", "blablakey"); Assert.Empty(grantor.GetTwoFactorProviders()); + Assert.False(grantor.VerifyDevices); await sutProvider.GetDependency().Received().ReplaceAsync(grantor); } + + [Theory, BitAutoData] + public async Task ViewAsync_EmergencyAccessTypeNotView_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ViewAsync(emergencyAccess.Id, granteeUser)); + } + + [Theory, BitAutoData] + public async Task GetAttachmentDownloadAsync_EmergencyAccessTypeNotView_ThrowsBadRequest( + SutProvider sutProvider, + EmergencyAccess emergencyAccess, + User granteeUser) + { + emergencyAccess.GranteeId = granteeUser.Id; + emergencyAccess.Type = EmergencyAccessType.Takeover; + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(emergencyAccess); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.GetAttachmentDownloadAsync(emergencyAccess.Id, default, default, granteeUser)); + } } diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index b1f78ed987..3fb134fda8 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -3,14 +3,11 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Services; using Bit.Core.Settings; -using Bit.Core.Test.Billing.Tax.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Braintree; @@ -195,7 +192,7 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); ; + .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); } #endregion @@ -1029,7 +1026,7 @@ public class SubscriberServiceTests stripeAdapter .PaymentMethodListAutoPagingAsync(Arg.Any()) - .Returns(GetPaymentMethodsAsync(new List())); + .Returns(GetPaymentMethodsAsync(new List())); await sutProvider.Sut.RemovePaymentSource(organization); @@ -1061,7 +1058,7 @@ public class SubscriberServiceTests stripeAdapter .PaymentMethodListAutoPagingAsync(Arg.Any()) - .Returns(GetPaymentMethodsAsync(new List + .Returns(GetPaymentMethodsAsync(new List { new () { @@ -1086,8 +1083,8 @@ public class SubscriberServiceTests .PaymentMethodDetachAsync(cardId); } - private static async IAsyncEnumerable GetPaymentMethodsAsync( - IEnumerable paymentMethods) + private static async IAsyncEnumerable GetPaymentMethodsAsync( + IEnumerable paymentMethods) { foreach (var paymentMethod in paymentMethods) { @@ -1598,14 +1595,22 @@ public class SubscriberServiceTests City = "Example Town", State = "NY" }, - TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } + TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] }, + Subscriptions = new StripeList + { + Data = [ + new Subscription + { + Id = provider.GatewaySubscriptionId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } }); var subscription = new Subscription { Items = new StripeList() }; sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) .Returns(subscription); - sutProvider.GetDependency().CreateAsync(Arg.Any()) - .Returns(new FakeAutomaticTaxStrategy(true)); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); @@ -1623,6 +1628,98 @@ public class SubscriberServiceTests await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( options => options.Type == "us_ein" && options.Value == taxInformation.TaxId)); + + await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); + } + + [Theory, BitAutoData] + public async Task UpdateTaxInformation_NonUser_ReverseCharge_MakesCorrectInvocations( + Provider provider, + SutProvider sutProvider) + { + var stripeAdapter = sutProvider.GetDependency(); + + var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; + + stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( + options => options.Expand.Contains("tax_ids"))).Returns(customer); + + var taxInformation = new TaxInformation( + "CA", + "12345", + "123456789", + "us_ein", + "123 Example St.", + null, + "Example Town", + "NY"); + + sutProvider.GetDependency() + .CustomerUpdateAsync( + Arg.Is(p => p == provider.GatewayCustomerId), + Arg.Is(options => + options.Address.Country == "CA" && + options.Address.PostalCode == "12345" && + options.Address.Line1 == "123 Example St." && + options.Address.Line2 == null && + options.Address.City == "Example Town" && + options.Address.State == "NY")) + .Returns(new Customer + { + Id = provider.GatewayCustomerId, + Address = new Address + { + Country = "CA", + PostalCode = "12345", + Line1 = "123 Example St.", + Line2 = null, + City = "Example Town", + State = "NY" + }, + TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] }, + Subscriptions = new StripeList + { + Data = [ + new Subscription + { + Id = provider.GatewaySubscriptionId, + CustomerId = provider.GatewayCustomerId, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false } + } + ] + } + }); + + var subscription = new Subscription { Items = new StripeList() }; + sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + .Returns(subscription); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM21092_SetNonUSBusinessUseToReverseCharge).Returns(true); + + await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); + + await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + options => + options.Address.Country == taxInformation.Country && + options.Address.PostalCode == taxInformation.PostalCode && + options.Address.Line1 == taxInformation.Line1 && + options.Address.Line2 == taxInformation.Line2 && + options.Address.City == taxInformation.City && + options.Address.State == taxInformation.State)); + + await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1"); + + await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( + options => options.Type == "us_ein" && + options.Value == taxInformation.TaxId)); + + await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, + Arg.Is(options => options.TaxExempt == StripeConstants.TaxExempt.Reverse)); + + await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + Arg.Is(options => options.AutomaticTax.Enabled == true)); } #endregion diff --git a/test/Core.Test/Models/Business/OrganizationLicenseFileFixtures.cs b/test/Core.Test/Models/Business/OrganizationLicenseFileFixtures.cs index 1004cefeca..08771df06a 100644 --- a/test/Core.Test/Models/Business/OrganizationLicenseFileFixtures.cs +++ b/test/Core.Test/Models/Business/OrganizationLicenseFileFixtures.cs @@ -28,7 +28,10 @@ public static class OrganizationLicenseFileFixtures private const string Version15 = "{\n 'LicenseKey': 'myLicenseKey',\n 'InstallationId': '78900000-0000-0000-0000-000000000123',\n 'Id': '12300000-0000-0000-0000-000000000456',\n 'Name': 'myOrg',\n 'BillingEmail': 'myBillingEmail',\n 'BusinessName': 'myBusinessName',\n 'Enabled': true,\n 'Plan': 'myPlan',\n 'PlanType': 11,\n 'Seats': 10,\n 'MaxCollections': 2,\n 'UsePolicies': true,\n 'UseSso': true,\n 'UseKeyConnector': true,\n 'UseScim': true,\n 'UseGroups': true,\n 'UseEvents': true,\n 'UseDirectory': true,\n 'UseTotp': true,\n 'Use2fa': true,\n 'UseApi': true,\n 'UseResetPassword': true,\n 'MaxStorageGb': 100,\n 'SelfHost': true,\n 'UsersGetPremium': true,\n 'UseCustomPermissions': true,\n 'Version': 14,\n 'Issued': '2023-12-14T02:03:33.374297Z',\n 'Refresh': '2023-12-07T22:42:33.970597Z',\n 'Expires': '2023-12-21T02:03:33.374297Z',\n 'ExpirationWithoutGracePeriod': null,\n 'UsePasswordManager': true,\n 'UseSecretsManager': true,\n 'SmSeats': 5,\n 'SmServiceAccounts': 8,\n 'LimitCollectionCreationDeletion': true,\n 'AllowAdminAccessToAllCollectionItems': true,\n 'Trial': true,\n 'LicenseType': 1,\n 'Hash': 'EZl4IvJaa1E5mPmlfp4p5twAtlmaxlF1yoZzVYP4vog=',\n 'Signature': ''\n}"; - private static readonly Dictionary LicenseVersions = new() { { 12, Version12 }, { 13, Version13 }, { 14, Version14 }, { 15, Version15 } }; + private const string Version16 = + "{\n'LicenseKey': 'myLicenseKey',\n'InstallationId': '78900000-0000-0000-0000-000000000123',\n'Id': '12300000-0000-0000-0000-000000000456',\n'Name': 'myOrg',\n'BillingEmail': 'myBillingEmail',\n'BusinessName': 'myBusinessName',\n'Enabled': true,\n'Plan': 'myPlan',\n'PlanType': 11,\n'Seats': 10,\n'MaxCollections': 2,\n'UsePolicies': true,\n'UseSso': true,\n'UseKeyConnector': true,\n'UseScim': true,\n'UseGroups': true,\n'UseEvents': true,\n'UseDirectory': true,\n'UseTotp': true,\n'Use2fa': true,\n'UseApi': true,\n'UseResetPassword': true,\n'MaxStorageGb': 100,\n'SelfHost': true,\n'UsersGetPremium': true,\n'UseCustomPermissions': true,\n'Version': 15,\n'Issued': '2025-05-16T20:50:09.036931Z',\n'Refresh': '2025-05-23T20:50:09.036931Z',\n'Expires': '2025-05-23T20:50:09.036931Z',\n'ExpirationWithoutGracePeriod': null,\n'UsePasswordManager': true,\n'UseSecretsManager': true,\n'SmSeats': 5,\n'SmServiceAccounts': 8,\n'UseRiskInsights': false,\n'LimitCollectionCreationDeletion': true,\n'AllowAdminAccessToAllCollectionItems': true,\n'Trial': true,\n'LicenseType': 1,\n'UseOrganizationDomains': true,\n'UseAdminSponsoredFamilies': false,\n'Hash': 'k3M9SpHKUo0TmuSnNipeZleCHxgcEycKRXYl9BAg30Q=',\n'Signature': '',\n'Token': null\n}"; + + private static readonly Dictionary LicenseVersions = new() { { 12, Version12 }, { 13, Version13 }, { 14, Version14 }, { 15, Version15 }, { 16, Version16 } }; public static OrganizationLicense GetVersion(int licenseVersion) { diff --git a/test/Core.Test/Services/UserServiceTests.cs b/test/Core.Test/Services/UserServiceTests.cs index 0458c7cdd9..ac7f6e4018 100644 --- a/test/Core.Test/Services/UserServiceTests.cs +++ b/test/Core.Test/Services/UserServiceTests.cs @@ -347,7 +347,7 @@ public class UserServiceTests SutProvider sutProvider, Guid userId, Organization organization) { organization.Enabled = true; - organization.UseSso = true; + organization.UseOrganizationDomains = true; sutProvider.GetDependency() .GetByVerifiedUserEmailDomainAsync(userId) @@ -362,7 +362,7 @@ public class UserServiceTests SutProvider sutProvider, Guid userId, Organization organization) { organization.Enabled = false; - organization.UseSso = true; + organization.UseOrganizationDomains = true; sutProvider.GetDependency() .GetByVerifiedUserEmailDomainAsync(userId) @@ -373,11 +373,11 @@ public class UserServiceTests } [Theory, BitAutoData] - public async Task IsClaimedByAnyOrganizationAsync_WithOrganizationUseSsoFalse_ReturnsFalse( + public async Task IsClaimedByAnyOrganizationAsync_WithOrganizationUseOrganizationDomaisFalse_ReturnsFalse( SutProvider sutProvider, Guid userId, Organization organization) { organization.Enabled = true; - organization.UseSso = false; + organization.UseOrganizationDomains = false; sutProvider.GetDependency() .GetByVerifiedUserEmailDomainAsync(userId)