diff --git a/.github/renovate.json b/.github/renovate.json index 18d6e0bb61..91774ca33e 100644 --- a/.github/renovate.json +++ b/.github/renovate.json @@ -44,6 +44,7 @@ "matchPackageNames": [ "AspNetCoreRateLimit", "AspNetCoreRateLimit.Redis", + "Azure.Data.Tables", "Azure.Extensions.AspNetCore.DataProtection.Blobs", "Azure.Messaging.EventGrid", "Azure.Messaging.ServiceBus", @@ -53,7 +54,6 @@ "Fido2.AspNet", "Duende.IdentityServer", "Microsoft.Azure.Cosmos", - "Microsoft.Azure.Cosmos.Table", "Microsoft.Extensions.Caching.StackExchangeRedis", "Microsoft.Extensions.Identity.Stores", "Otp.NET", diff --git a/.github/workflows/scan.yml b/.github/workflows/scan.yml index 62203804b9..89d75ccf0f 100644 --- a/.github/workflows/scan.yml +++ b/.github/workflows/scan.yml @@ -10,8 +10,6 @@ on: pull_request_target: types: [opened, synchronize] -permissions: read-all - jobs: check-run: name: Check PR run @@ -22,6 +20,8 @@ jobs: runs-on: ubuntu-22.04 needs: check-run permissions: + contents: read + pull-requests: write security-events: write steps: @@ -43,7 +43,7 @@ jobs: additional_params: --report-format sarif --output-path . ${{ env.INCREMENTAL }} - name: Upload Checkmarx results to GitHub - uses: github/codeql-action/upload-sarif@8a470fddafa5cbb6266ee11b37ef4d8aae19c571 # v3.24.6 + uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9 with: sarif_file: cx_result.sarif @@ -51,6 +51,9 @@ jobs: name: Quality scan runs-on: ubuntu-22.04 needs: check-run + permissions: + contents: read + pull-requests: write steps: - name: Check out repo diff --git a/bitwarden_license/src/Scim/appsettings.json b/bitwarden_license/src/Scim/appsettings.json index 630896a65f..dcdfeb3ede 100644 --- a/bitwarden_license/src/Scim/appsettings.json +++ b/bitwarden_license/src/Scim/appsettings.json @@ -30,10 +30,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - }, "sentry": { "dsn": "SECRET" }, @@ -58,6 +54,5 @@ "region": "SECRET" } }, - "scimSettings": { - } + "scimSettings": {} } diff --git a/bitwarden_license/src/Sso/appsettings.json b/bitwarden_license/src/Sso/appsettings.json index 3bf02cd869..73c85044cc 100644 --- a/bitwarden_license/src/Sso/appsettings.json +++ b/bitwarden_license/src/Sso/appsettings.json @@ -31,10 +31,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" diff --git a/src/Admin/Jobs/JobsHostedService.cs b/src/Admin/Jobs/JobsHostedService.cs index adba27970c..89cf5512c3 100644 --- a/src/Admin/Jobs/JobsHostedService.cs +++ b/src/Admin/Jobs/JobsHostedService.cs @@ -76,14 +76,18 @@ public class JobsHostedService : BaseJobsHostedService { new Tuple(typeof(DeleteSendsJob), everyFiveMinutesTrigger), new Tuple(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger), - new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger), - new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger), new Tuple(typeof(DeleteCiphersJob), everyDayAtMidnightUtc), new Tuple(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger), new Tuple(typeof(DeleteAuthRequestsJob), everyFifteenMinutesTrigger), new Tuple(typeof(DeleteUnverifiedOrganizationDomainsJob), everyDayAtTwoAmUtcTrigger), }; + if (!(_globalSettings.SqlServer?.DisableDatabaseMaintenanceJobs ?? false)) + { + jobs.Add(new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger)); + jobs.Add(new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger)); + } + if (!_globalSettings.SelfHosted) { jobs.Add(new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger)); diff --git a/src/Admin/Startup.cs b/src/Admin/Startup.cs index db870266cc..788908d42a 100644 --- a/src/Admin/Startup.cs +++ b/src/Admin/Startup.cs @@ -88,7 +88,7 @@ public class Startup services.AddBaseServices(globalSettings); services.AddDefaultServices(globalSettings); services.AddScoped(); - services.AddBillingCommands(); + services.AddBillingOperations(); #if OSS services.AddOosServices(); diff --git a/src/Admin/appsettings.json b/src/Admin/appsettings.json index 4764484204..9513dc44a2 100644 --- a/src/Admin/appsettings.json +++ b/src/Admin/appsettings.json @@ -30,10 +30,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs index 2a4ba3a1db..822f9635eb 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs @@ -66,7 +66,7 @@ public class OrganizationsController : Controller private readonly IAddSecretsManagerSubscriptionCommand _addSecretsManagerSubscriptionCommand; private readonly IPushNotificationService _pushNotificationService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly IGetSubscriptionQuery _getSubscriptionQuery; + private readonly ISubscriberQueries _subscriberQueries; private readonly IReferenceEventService _referenceEventService; private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand; @@ -93,7 +93,7 @@ public class OrganizationsController : Controller IAddSecretsManagerSubscriptionCommand addSecretsManagerSubscriptionCommand, IPushNotificationService pushNotificationService, ICancelSubscriptionCommand cancelSubscriptionCommand, - IGetSubscriptionQuery getSubscriptionQuery, + ISubscriberQueries subscriberQueries, IReferenceEventService referenceEventService, IOrganizationEnableCollectionEnhancementsCommand organizationEnableCollectionEnhancementsCommand) { @@ -119,7 +119,7 @@ public class OrganizationsController : Controller _addSecretsManagerSubscriptionCommand = addSecretsManagerSubscriptionCommand; _pushNotificationService = pushNotificationService; _cancelSubscriptionCommand = cancelSubscriptionCommand; - _getSubscriptionQuery = getSubscriptionQuery; + _subscriberQueries = subscriberQueries; _referenceEventService = referenceEventService; _organizationEnableCollectionEnhancementsCommand = organizationEnableCollectionEnhancementsCommand; } @@ -479,7 +479,7 @@ public class OrganizationsController : Controller throw new NotFoundException(); } - var subscription = await _getSubscriptionQuery.GetSubscription(organization); + var subscription = await _subscriberQueries.GetSubscriptionOrThrow(organization); await _cancelSubscriptionCommand.CancelSubscription(subscription, new OffboardingSurveyResponse diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs index df75a34f69..767f83ee22 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs @@ -126,8 +126,14 @@ public class OrganizationSubscriptionResponseModel : OrganizationResponseModel if (hideSensitiveData) { BillingEmail = null; - Subscription.Items = null; - UpcomingInvoice.Amount = null; + if (Subscription != null) + { + Subscription.Items = null; + } + if (UpcomingInvoice != null) + { + UpcomingInvoice.Amount = null; + } } } diff --git a/src/Api/Auth/Controllers/AccountsController.cs b/src/Api/Auth/Controllers/AccountsController.cs index 29ede684be..5f1910fb28 100644 --- a/src/Api/Auth/Controllers/AccountsController.cs +++ b/src/Api/Auth/Controllers/AccountsController.cs @@ -69,7 +69,7 @@ public class AccountsController : Controller private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IFeatureService _featureService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly IGetSubscriptionQuery _getSubscriptionQuery; + private readonly ISubscriberQueries _subscriberQueries; private readonly IReferenceEventService _referenceEventService; private readonly ICurrentContext _currentContext; @@ -104,7 +104,7 @@ public class AccountsController : Controller IRotateUserKeyCommand rotateUserKeyCommand, IFeatureService featureService, ICancelSubscriptionCommand cancelSubscriptionCommand, - IGetSubscriptionQuery getSubscriptionQuery, + ISubscriberQueries subscriberQueries, IReferenceEventService referenceEventService, ICurrentContext currentContext, IRotationValidator, IEnumerable> cipherValidator, @@ -133,7 +133,7 @@ public class AccountsController : Controller _rotateUserKeyCommand = rotateUserKeyCommand; _featureService = featureService; _cancelSubscriptionCommand = cancelSubscriptionCommand; - _getSubscriptionQuery = getSubscriptionQuery; + _subscriberQueries = subscriberQueries; _referenceEventService = referenceEventService; _currentContext = currentContext; _cipherValidator = cipherValidator; @@ -831,7 +831,7 @@ public class AccountsController : Controller throw new UnauthorizedAccessException(); } - var subscription = await _getSubscriptionQuery.GetSubscription(user); + var subscription = await _subscriberQueries.GetSubscriptionOrThrow(user); await _cancelSubscriptionCommand.CancelSubscription(subscription, new OffboardingSurveyResponse diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs new file mode 100644 index 0000000000..583a5937e4 --- /dev/null +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -0,0 +1,44 @@ +using Bit.Api.Billing.Models; +using Bit.Core; +using Bit.Core.Billing.Queries; +using Bit.Core.Context; +using Bit.Core.Services; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.Billing.Controllers; + +[Route("providers/{providerId:guid}/billing")] +[Authorize("Application")] +public class ProviderBillingController( + ICurrentContext currentContext, + IFeatureService featureService, + IProviderBillingQueries providerBillingQueries) : Controller +{ + [HttpGet("subscription")] + public async Task GetSubscriptionAsync([FromRoute] Guid providerId) + { + if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { + return TypedResults.NotFound(); + } + + if (!currentContext.ProviderProviderAdmin(providerId)) + { + return TypedResults.Unauthorized(); + } + + var subscriptionData = await providerBillingQueries.GetSubscriptionData(providerId); + + if (subscriptionData == null) + { + return TypedResults.NotFound(); + } + + var (providerPlans, subscription) = subscriptionData; + + var providerSubscriptionDTO = ProviderSubscriptionDTO.From(providerPlans, subscription); + + return TypedResults.Ok(providerSubscriptionDTO); + } +} diff --git a/src/Api/Billing/Controllers/ProviderOrganizationController.cs b/src/Api/Billing/Controllers/ProviderOrganizationController.cs new file mode 100644 index 0000000000..a5cc31c79c --- /dev/null +++ b/src/Api/Billing/Controllers/ProviderOrganizationController.cs @@ -0,0 +1,63 @@ +using Bit.Api.Billing.Models; +using Bit.Core; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Commands; +using Bit.Core.Context; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.Billing.Controllers; + +[Route("providers/{providerId:guid}/organizations")] +public class ProviderOrganizationController( + IAssignSeatsToClientOrganizationCommand assignSeatsToClientOrganizationCommand, + ICurrentContext currentContext, + IFeatureService featureService, + ILogger logger, + IOrganizationRepository organizationRepository, + IProviderRepository providerRepository, + IProviderOrganizationRepository providerOrganizationRepository) : Controller +{ + [HttpPut("{providerOrganizationId:guid}")] + public async Task UpdateAsync( + [FromRoute] Guid providerId, + [FromRoute] Guid providerOrganizationId, + [FromBody] UpdateProviderOrganizationRequestBody requestBody) + { + if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { + return TypedResults.NotFound(); + } + + if (!currentContext.ProviderProviderAdmin(providerId)) + { + return TypedResults.Unauthorized(); + } + + var provider = await providerRepository.GetByIdAsync(providerId); + + var providerOrganization = await providerOrganizationRepository.GetByIdAsync(providerOrganizationId); + + if (provider == null || providerOrganization == null) + { + return TypedResults.NotFound(); + } + + var organization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId); + + if (organization == null) + { + logger.LogError("The organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id); + + return TypedResults.Problem(); + } + + await assignSeatsToClientOrganizationCommand.AssignSeatsToClientOrganization( + provider, + organization, + requestBody.AssignedSeats); + + return TypedResults.Ok(); + } +} diff --git a/src/Api/Billing/Models/ProviderSubscriptionDTO.cs b/src/Api/Billing/Models/ProviderSubscriptionDTO.cs new file mode 100644 index 0000000000..ad0714967d --- /dev/null +++ b/src/Api/Billing/Models/ProviderSubscriptionDTO.cs @@ -0,0 +1,49 @@ +using Bit.Core.Billing.Models; +using Bit.Core.Utilities; +using Stripe; + +namespace Bit.Api.Billing.Models; + +public record ProviderSubscriptionDTO( + string Status, + DateTime CurrentPeriodEndDate, + decimal? DiscountPercentage, + IEnumerable Plans) +{ + private const string _annualCadence = "Annual"; + private const string _monthlyCadence = "Monthly"; + + public static ProviderSubscriptionDTO From( + IEnumerable providerPlans, + Subscription subscription) + { + var providerPlansDTO = providerPlans + .Select(providerPlan => + { + var plan = StaticStore.GetPlan(providerPlan.PlanType); + var cost = (providerPlan.SeatMinimum + providerPlan.PurchasedSeats) * plan.PasswordManager.SeatPrice; + var cadence = plan.IsAnnual ? _annualCadence : _monthlyCadence; + return new ProviderPlanDTO( + plan.Name, + providerPlan.SeatMinimum, + providerPlan.PurchasedSeats, + providerPlan.AssignedSeats, + cost, + cadence); + }); + + return new ProviderSubscriptionDTO( + subscription.Status, + subscription.CurrentPeriodEnd, + subscription.Customer?.Discount?.Coupon?.PercentOff, + providerPlansDTO); + } +} + +public record ProviderPlanDTO( + string PlanName, + int SeatMinimum, + int PurchasedSeats, + int AssignedSeats, + decimal Cost, + string Cadence); diff --git a/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs b/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs new file mode 100644 index 0000000000..7bac8fdef4 --- /dev/null +++ b/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs @@ -0,0 +1,6 @@ +namespace Bit.Api.Billing.Models; + +public class UpdateProviderOrganizationRequestBody +{ + public int AssignedSeats { get; set; } +} diff --git a/src/Api/Controllers/CollectionsController.cs b/src/Api/Controllers/CollectionsController.cs index 3eeae17a50..7711e44220 100644 --- a/src/Api/Controllers/CollectionsController.cs +++ b/src/Api/Controllers/CollectionsController.cs @@ -2,7 +2,6 @@ using Bit.Api.Models.Response; using Bit.Api.Utilities; using Bit.Api.Vault.AuthorizationHandlers.Collections; -using Bit.Core; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -11,7 +10,6 @@ using Bit.Core.Models.Data; using Bit.Core.OrganizationFeatures.OrganizationCollections.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -322,7 +320,6 @@ public class CollectionsController : Controller } [HttpPost("bulk-access")] - [RequireFeature(FeatureFlagKeys.BulkCollectionAccess)] public async Task PostBulkCollectionAccess(Guid orgId, [FromBody] BulkCollectionAccessRequestModel model) { // Authorization logic assumes flexible collections is enabled diff --git a/src/Api/Models/Response/SubscriptionResponseModel.cs b/src/Api/Models/Response/SubscriptionResponseModel.cs index 7ba2b857eb..cca4f8ae72 100644 --- a/src/Api/Models/Response/SubscriptionResponseModel.cs +++ b/src/Api/Models/Response/SubscriptionResponseModel.cs @@ -75,6 +75,10 @@ public class BillingSubscription { Items = sub.Items.Select(i => new BillingSubscriptionItem(i)); } + CollectionMethod = sub.CollectionMethod; + SuspensionDate = sub.SuspensionDate; + UnpaidPeriodEndDate = sub.UnpaidPeriodEndDate; + GracePeriod = sub.GracePeriod; } public DateTime? TrialStartDate { get; set; } @@ -86,6 +90,10 @@ public class BillingSubscription public string Status { get; set; } public bool Cancelled { get; set; } public IEnumerable Items { get; set; } = new List(); + public string CollectionMethod { get; set; } + public DateTime? SuspensionDate { get; set; } + public DateTime? UnpaidPeriodEndDate { get; set; } + public int? GracePeriod { get; set; } public class BillingSubscriptionItem { diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index 9f94325513..63b1a3c3cd 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -170,8 +170,7 @@ public class Startup services.AddDefaultServices(globalSettings); services.AddOrganizationSubscriptionServices(); services.AddCoreLocalizationServices(); - services.AddBillingCommands(); - services.AddBillingQueries(); + services.AddBillingOperations(); // Authorization Handlers services.AddAuthorizationHandlers(); diff --git a/src/Api/appsettings.json b/src/Api/appsettings.json index e49491857f..c04539a9fe 100644 --- a/src/Api/appsettings.json +++ b/src/Api/appsettings.json @@ -32,10 +32,6 @@ "send": { "connectionString": "SECRET" }, - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - }, "sentry": { "dsn": "SECRET" }, diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index e78ed31ff1..679dea15ce 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -868,7 +868,7 @@ public class StripeController : Controller private bool UnpaidAutoChargeInvoiceForSubscriptionCycle(Invoice invoice) { return invoice.AmountDue > 0 && !invoice.Paid && invoice.CollectionMethod == "charge_automatically" && - invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null; + invoice.BillingReason is "subscription_cycle" or "automatic_pending_invoice_item_invoice" && invoice.SubscriptionId != null; } private async Task VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription) diff --git a/src/Billing/appsettings.json b/src/Billing/appsettings.json index 93d103aa80..4985784573 100644 --- a/src/Billing/appsettings.json +++ b/src/Billing/appsettings.json @@ -30,10 +30,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - }, "sentry": { "dsn": "SECRET" }, diff --git a/src/Core/AdminConsole/Entities/Provider/Provider.cs b/src/Core/AdminConsole/Entities/Provider/Provider.cs index ee2b35ed90..e5b794e6b1 100644 --- a/src/Core/AdminConsole/Entities/Provider/Provider.cs +++ b/src/Core/AdminConsole/Entities/Provider/Provider.cs @@ -6,7 +6,7 @@ using Bit.Core.Utilities; namespace Bit.Core.AdminConsole.Entities.Provider; -public class Provider : ITableObject +public class Provider : ITableObject, ISubscriber { public Guid Id { get; set; } /// @@ -34,6 +34,26 @@ public class Provider : ITableObject public string GatewayCustomerId { get; set; } public string GatewaySubscriptionId { get; set; } + public string BillingEmailAddress() => BillingEmail?.ToLowerInvariant().Trim(); + + public string BillingName() => DisplayBusinessName(); + + public string SubscriberName() => DisplayName(); + + public string BraintreeCustomerIdPrefix() => "p"; + + public string BraintreeIdField() => "provider_id"; + + public string BraintreeCloudRegionField() => "region"; + + public bool IsOrganization() => false; + + public bool IsUser() => false; + + public string SubscriberType() => "Provider"; + + public bool IsExpired() => false; + public void SetNewId() { if (Id == default) diff --git a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs index 39cc5a1d98..80a16e495a 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs @@ -146,7 +146,8 @@ public class SelfHostedOrganizationDetails : Organization OwnersNotifiedOfAutoscaling = OwnersNotifiedOfAutoscaling, LimitCollectionCreationDeletion = LimitCollectionCreationDeletion, AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems, - FlexibleCollections = FlexibleCollections + FlexibleCollections = FlexibleCollections, + Status = Status }; } } diff --git a/src/Core/Billing/BillingException.cs b/src/Core/Billing/BillingException.cs new file mode 100644 index 0000000000..a6944b3ed6 --- /dev/null +++ b/src/Core/Billing/BillingException.cs @@ -0,0 +1,9 @@ +namespace Bit.Core.Billing; + +public class BillingException( + string clientFriendlyMessage, + string internalMessage = null, + Exception innerException = null) : Exception(internalMessage, innerException) +{ + public string ClientFriendlyMessage { get; set; } = clientFriendlyMessage; +} diff --git a/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs b/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs new file mode 100644 index 0000000000..db21926bec --- /dev/null +++ b/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs @@ -0,0 +1,12 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; + +namespace Bit.Core.Billing.Commands; + +public interface IAssignSeatsToClientOrganizationCommand +{ + Task AssignSeatsToClientOrganization( + Provider provider, + Organization organization, + int seats); +} diff --git a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs b/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs index b23880e650..88708d3d2e 100644 --- a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs +++ b/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs @@ -1,7 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Models; using Bit.Core.Entities; -using Bit.Core.Exceptions; using Stripe; namespace Bit.Core.Billing.Commands; @@ -17,7 +16,6 @@ public interface ICancelSubscriptionCommand /// The or with the subscription to cancel. /// An DTO containing user-provided feedback on why they are cancelling the subscription. /// A flag indicating whether to cancel the subscription immediately or at the end of the subscription period. - /// Thrown when the provided subscription is already in an inactive state. Task CancelSubscription( Subscription subscription, OffboardingSurveyResponse offboardingSurveyResponse, diff --git a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs index 62bf0d0926..e2be6f45eb 100644 --- a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs +++ b/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs @@ -4,5 +4,12 @@ namespace Bit.Core.Billing.Commands; public interface IRemovePaymentMethodCommand { + /// + /// Attempts to remove an Organization's saved payment method. If the Stripe representing the + /// contains a valid "btCustomerId" key in its property, + /// this command will attempt to remove the Braintree . Otherwise, it will attempt to remove the + /// Stripe . + /// + /// The organization to remove the saved payment method for. Task RemovePaymentMethod(Organization organization); } diff --git a/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs b/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs new file mode 100644 index 0000000000..be2c6be968 --- /dev/null +++ b/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs @@ -0,0 +1,174 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Entities; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Repositories; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using static Bit.Core.Billing.Utilities; + +namespace Bit.Core.Billing.Commands.Implementations; + +public class AssignSeatsToClientOrganizationCommand( + ILogger logger, + IOrganizationRepository organizationRepository, + IPaymentService paymentService, + IProviderBillingQueries providerBillingQueries, + IProviderPlanRepository providerPlanRepository) : IAssignSeatsToClientOrganizationCommand +{ + public async Task AssignSeatsToClientOrganization( + Provider provider, + Organization organization, + int seats) + { + ArgumentNullException.ThrowIfNull(provider); + ArgumentNullException.ThrowIfNull(organization); + + if (provider.Type == ProviderType.Reseller) + { + logger.LogError("Reseller-type provider ({ID}) cannot assign seats to client organizations", provider.Id); + + throw ContactSupport("Consolidated billing does not support reseller-type providers"); + } + + if (seats < 0) + { + throw new BillingException( + "You cannot assign negative seats to a client.", + "MSP cannot assign negative seats to a client organization"); + } + + if (seats == organization.Seats) + { + logger.LogWarning("Client organization ({ID}) already has {Seats} seats assigned", organization.Id, organization.Seats); + + return; + } + + var providerPlan = await GetProviderPlanAsync(provider, organization); + + var providerSeatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0); + + // How many seats the provider has assigned to all their client organizations that have the specified plan type. + var providerCurrentlyAssignedSeatTotal = await providerBillingQueries.GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType); + + // How many seats are being added to or subtracted from this client organization. + var seatDifference = seats - (organization.Seats ?? 0); + + // How many seats the provider will have assigned to all of their client organizations after the update. + var providerNewlyAssignedSeatTotal = providerCurrentlyAssignedSeatTotal + seatDifference; + + var update = CurryUpdateFunction( + provider, + providerPlan, + organization, + seats, + providerNewlyAssignedSeatTotal); + + /* + * Below the limit => Below the limit: + * No subscription update required. We can safely update the organization's seats. + */ + if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum && + providerNewlyAssignedSeatTotal <= providerSeatMinimum) + { + organization.Seats = seats; + + await organizationRepository.ReplaceAsync(organization); + + providerPlan.AllocatedSeats = providerNewlyAssignedSeatTotal; + + await providerPlanRepository.ReplaceAsync(providerPlan); + } + /* + * Below the limit => Above the limit: + * We have to scale the subscription up from the seat minimum to the newly assigned seat total. + */ + else if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum && + providerNewlyAssignedSeatTotal > providerSeatMinimum) + { + await update( + providerSeatMinimum, + providerNewlyAssignedSeatTotal); + } + /* + * Above the limit => Above the limit: + * We have to scale the subscription from the currently assigned seat total to the newly assigned seat total. + */ + else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum && + providerNewlyAssignedSeatTotal > providerSeatMinimum) + { + await update( + providerCurrentlyAssignedSeatTotal, + providerNewlyAssignedSeatTotal); + } + /* + * Above the limit => Below the limit: + * We have to scale the subscription down from the currently assigned seat total to the seat minimum. + */ + else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum && + providerNewlyAssignedSeatTotal <= providerSeatMinimum) + { + await update( + providerCurrentlyAssignedSeatTotal, + providerSeatMinimum); + } + } + + // ReSharper disable once SuggestBaseTypeForParameter + private async Task GetProviderPlanAsync(Provider provider, Organization organization) + { + if (!organization.PlanType.SupportsConsolidatedBilling()) + { + logger.LogError("Cannot assign seats to a client organization ({ID}) with a plan type that does not support consolidated billing: {PlanType}", organization.Id, organization.PlanType); + + throw ContactSupport(); + } + + var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); + + var providerPlan = providerPlans.FirstOrDefault(providerPlan => providerPlan.PlanType == organization.PlanType); + + if (providerPlan != null && providerPlan.IsConfigured()) + { + return providerPlan; + } + + logger.LogError("Cannot assign seats to client organization ({ClientOrganizationID}) when provider's ({ProviderID}) matching plan is not configured", organization.Id, provider.Id); + + throw ContactSupport(); + } + + private Func CurryUpdateFunction( + Provider provider, + ProviderPlan providerPlan, + Organization organization, + int organizationNewlyAssignedSeats, + int providerNewlyAssignedSeats) => async (providerCurrentlySubscribedSeats, providerNewlySubscribedSeats) => + { + var plan = StaticStore.GetPlan(providerPlan.PlanType); + + await paymentService.AdjustSeats( + provider, + plan, + providerCurrentlySubscribedSeats, + providerNewlySubscribedSeats); + + organization.Seats = organizationNewlyAssignedSeats; + + await organizationRepository.ReplaceAsync(organization); + + var providerNewlyPurchasedSeats = providerNewlySubscribedSeats > providerPlan.SeatMinimum + ? providerNewlySubscribedSeats - providerPlan.SeatMinimum + : 0; + + providerPlan.PurchasedSeats = providerNewlyPurchasedSeats; + providerPlan.AllocatedSeats = providerNewlyAssignedSeats; + + await providerPlanRepository.ReplaceAsync(providerPlan); + }; +} diff --git a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs index c5dbb6d927..be8479ea99 100644 --- a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs +++ b/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs @@ -1,55 +1,41 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Enums; -using Bit.Core.Exceptions; using Bit.Core.Services; using Braintree; using Microsoft.Extensions.Logging; +using static Bit.Core.Billing.Utilities; + namespace Bit.Core.Billing.Commands.Implementations; -public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand +public class RemovePaymentMethodCommand( + IBraintreeGateway braintreeGateway, + ILogger logger, + IStripeAdapter stripeAdapter) + : IRemovePaymentMethodCommand { - private readonly IBraintreeGateway _braintreeGateway; - private readonly ILogger _logger; - private readonly IStripeAdapter _stripeAdapter; - - public RemovePaymentMethodCommand( - IBraintreeGateway braintreeGateway, - ILogger logger, - IStripeAdapter stripeAdapter) - { - _braintreeGateway = braintreeGateway; - _logger = logger; - _stripeAdapter = stripeAdapter; - } - public async Task RemovePaymentMethod(Organization organization) { - const string braintreeCustomerIdKey = "btCustomerId"; - - if (organization == null) - { - throw new ArgumentNullException(nameof(organization)); - } + ArgumentNullException.ThrowIfNull(organization); if (organization.Gateway is not GatewayType.Stripe || string.IsNullOrEmpty(organization.GatewayCustomerId)) { throw ContactSupport(); } - var stripeCustomer = await _stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, new Stripe.CustomerGetOptions + var stripeCustomer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, new Stripe.CustomerGetOptions { - Expand = new List { "invoice_settings.default_payment_method", "sources" } + Expand = ["invoice_settings.default_payment_method", "sources"] }); if (stripeCustomer == null) { - _logger.LogError("Could not find Stripe customer ({ID}) when removing payment method", organization.GatewayCustomerId); + logger.LogError("Could not find Stripe customer ({ID}) when removing payment method", organization.GatewayCustomerId); throw ContactSupport(); } - if (stripeCustomer.Metadata?.TryGetValue(braintreeCustomerIdKey, out var braintreeCustomerId) ?? false) + if (stripeCustomer.Metadata?.TryGetValue(BraintreeCustomerIdKey, out var braintreeCustomerId) ?? false) { await RemoveBraintreePaymentMethodAsync(braintreeCustomerId); } @@ -61,11 +47,11 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand private async Task RemoveBraintreePaymentMethodAsync(string braintreeCustomerId) { - var customer = await _braintreeGateway.Customer.FindAsync(braintreeCustomerId); + var customer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); if (customer == null) { - _logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); + logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); throw ContactSupport(); } @@ -74,27 +60,27 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand { var existingDefaultPaymentMethod = customer.DefaultPaymentMethod; - var updateCustomerResult = await _braintreeGateway.Customer.UpdateAsync( + var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync( braintreeCustomerId, new CustomerRequest { DefaultPaymentMethodToken = null }); if (!updateCustomerResult.IsSuccess()) { - _logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", + logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", braintreeCustomerId, updateCustomerResult.Message); throw ContactSupport(); } - var deletePaymentMethodResult = await _braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); + var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); if (!deletePaymentMethodResult.IsSuccess()) { - await _braintreeGateway.Customer.UpdateAsync( + await braintreeGateway.Customer.UpdateAsync( braintreeCustomerId, new CustomerRequest { DefaultPaymentMethodToken = existingDefaultPaymentMethod.Token }); - _logger.LogError( + logger.LogError( "Failed to delete Braintree payment method for Customer ({ID}), re-linked payment method. Message: {Message}", braintreeCustomerId, deletePaymentMethodResult.Message); @@ -103,7 +89,7 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand } else { - _logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId); + logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId); } } @@ -116,25 +102,23 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand switch (source) { case Stripe.BankAccount: - await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); + await stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); break; case Stripe.Card: - await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id); + await stripeAdapter.CardDeleteAsync(customer.Id, source.Id); break; } } } - var paymentMethods = _stripeAdapter.PaymentMethodListAutoPagingAsync(new Stripe.PaymentMethodListOptions + var paymentMethods = stripeAdapter.PaymentMethodListAutoPagingAsync(new Stripe.PaymentMethodListOptions { Customer = customer.Id }); await foreach (var paymentMethod in paymentMethods) { - await _stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id, new Stripe.PaymentMethodDetachOptions()); + await stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id, new Stripe.PaymentMethodDetachOptions()); } } - - private static GatewayException ContactSupport() => new("Could not remove your payment method. Please contact support for assistance."); } diff --git a/src/Core/Billing/Entities/ProviderPlan.cs b/src/Core/Billing/Entities/ProviderPlan.cs index 325dbbb156..f4965570d9 100644 --- a/src/Core/Billing/Entities/ProviderPlan.cs +++ b/src/Core/Billing/Entities/ProviderPlan.cs @@ -20,4 +20,6 @@ public class ProviderPlan : ITableObject Id = CoreHelpers.GenerateComb(); } } + + public bool IsConfigured() => SeatMinimum.HasValue && PurchasedSeats.HasValue && AllocatedSeats.HasValue; } diff --git a/src/Core/Billing/Extensions/BillingExtensions.cs b/src/Core/Billing/Extensions/BillingExtensions.cs new file mode 100644 index 0000000000..c7abeb81e2 --- /dev/null +++ b/src/Core/Billing/Extensions/BillingExtensions.cs @@ -0,0 +1,9 @@ +using Bit.Core.Enums; + +namespace Bit.Core.Billing.Extensions; + +public static class BillingExtensions +{ + public static bool SupportsConsolidatedBilling(this PlanType planType) + => planType is PlanType.TeamsMonthly or PlanType.EnterpriseMonthly; +} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 113fa4d5b7..8e28b23397 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -9,14 +9,15 @@ using Microsoft.Extensions.DependencyInjection; public static class ServiceCollectionExtensions { - public static void AddBillingCommands(this IServiceCollection services) + public static void AddBillingOperations(this IServiceCollection services) { - services.AddSingleton(); - services.AddSingleton(); - } + // Queries + services.AddTransient(); + services.AddTransient(); - public static void AddBillingQueries(this IServiceCollection services) - { - services.AddSingleton(); + // Commands + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); } } diff --git a/src/Core/Billing/Models/ConfiguredProviderPlan.cs b/src/Core/Billing/Models/ConfiguredProviderPlan.cs new file mode 100644 index 0000000000..d6bc2b7522 --- /dev/null +++ b/src/Core/Billing/Models/ConfiguredProviderPlan.cs @@ -0,0 +1,24 @@ +using Bit.Core.Billing.Entities; +using Bit.Core.Enums; + +namespace Bit.Core.Billing.Models; + +public record ConfiguredProviderPlan( + Guid Id, + Guid ProviderId, + PlanType PlanType, + int SeatMinimum, + int PurchasedSeats, + int AssignedSeats) +{ + public static ConfiguredProviderPlan From(ProviderPlan providerPlan) => + providerPlan.IsConfigured() + ? new ConfiguredProviderPlan( + providerPlan.Id, + providerPlan.ProviderId, + providerPlan.PlanType, + providerPlan.SeatMinimum.GetValueOrDefault(0), + providerPlan.PurchasedSeats.GetValueOrDefault(0), + providerPlan.AllocatedSeats.GetValueOrDefault(0)) + : null; +} diff --git a/src/Core/Billing/Models/ProviderSubscriptionData.cs b/src/Core/Billing/Models/ProviderSubscriptionData.cs new file mode 100644 index 0000000000..27da6cd226 --- /dev/null +++ b/src/Core/Billing/Models/ProviderSubscriptionData.cs @@ -0,0 +1,7 @@ +using Stripe; + +namespace Bit.Core.Billing.Models; + +public record ProviderSubscriptionData( + List ProviderPlans, + Subscription Subscription); diff --git a/src/Core/Billing/Queries/IGetSubscriptionQuery.cs b/src/Core/Billing/Queries/IGetSubscriptionQuery.cs deleted file mode 100644 index 9ba2a85ed5..0000000000 --- a/src/Core/Billing/Queries/IGetSubscriptionQuery.cs +++ /dev/null @@ -1,18 +0,0 @@ -using Bit.Core.Entities; -using Bit.Core.Exceptions; -using Stripe; - -namespace Bit.Core.Billing.Queries; - -public interface IGetSubscriptionQuery -{ - /// - /// Retrieves a Stripe using the 's property. - /// - /// The organization or user to retrieve the subscription for. - /// A Stripe . - /// Thrown when the is . - /// Thrown when the subscriber's is or empty. - /// Thrown when the returned from Stripe's API is null. - Task GetSubscription(ISubscriber subscriber); -} diff --git a/src/Core/Billing/Queries/IProviderBillingQueries.cs b/src/Core/Billing/Queries/IProviderBillingQueries.cs new file mode 100644 index 0000000000..e4b7d0f14d --- /dev/null +++ b/src/Core/Billing/Queries/IProviderBillingQueries.cs @@ -0,0 +1,27 @@ +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Models; +using Bit.Core.Enums; + +namespace Bit.Core.Billing.Queries; + +public interface IProviderBillingQueries +{ + /// + /// Retrieves the number of seats an MSP has assigned to its client organizations with a specified . + /// + /// The ID of the MSP to retrieve the assigned seat total for. + /// The type of plan to retrieve the assigned seat total for. + /// An representing the number of seats the provider has assigned to its client organizations with the specified . + /// Thrown when the provider represented by the is . + /// Thrown when the provider represented by the has . + Task GetAssignedSeatTotalForPlanOrThrow(Guid providerId, PlanType planType); + + /// + /// Retrieves a provider's billing subscription data. + /// + /// The ID of the provider to retrieve subscription data for. + /// A object containing the provider's Stripe and their s. + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetSubscriptionData(Guid providerId); +} diff --git a/src/Core/Billing/Queries/ISubscriberQueries.cs b/src/Core/Billing/Queries/ISubscriberQueries.cs new file mode 100644 index 0000000000..ea6c0d985e --- /dev/null +++ b/src/Core/Billing/Queries/ISubscriberQueries.cs @@ -0,0 +1,30 @@ +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Stripe; + +namespace Bit.Core.Billing.Queries; + +public interface ISubscriberQueries +{ + /// + /// Retrieves a Stripe using the 's property. + /// + /// The organization, provider or user to retrieve the subscription for. + /// Optional parameters that can be passed to Stripe to expand or modify the . + /// A Stripe . + /// Thrown when the is . + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetSubscription( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null); + + /// + /// Retrieves a Stripe using the 's property. + /// + /// The organization or user to retrieve the subscription for. + /// A Stripe . + /// Thrown when the is . + /// Thrown when the subscriber's is or empty. + /// Thrown when the returned from Stripe's API is null. + Task GetSubscriptionOrThrow(ISubscriber subscriber); +} diff --git a/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs b/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs deleted file mode 100644 index c3b0a29552..0000000000 --- a/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs +++ /dev/null @@ -1,36 +0,0 @@ -using Bit.Core.Entities; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; - -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Queries.Implementations; - -public class GetSubscriptionQuery( - ILogger logger, - IStripeAdapter stripeAdapter) : IGetSubscriptionQuery -{ - public async Task GetSubscription(ISubscriber subscriber) - { - ArgumentNullException.ThrowIfNull(subscriber); - - if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) - { - logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); - - throw ContactSupport(); - } - - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - - if (subscription != null) - { - return subscription; - } - - logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); - - throw ContactSupport(); - } -} diff --git a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs new file mode 100644 index 0000000000..f8bff9d3fd --- /dev/null +++ b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs @@ -0,0 +1,92 @@ +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Repositories; +using Bit.Core.Enums; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using Stripe; +using static Bit.Core.Billing.Utilities; + +namespace Bit.Core.Billing.Queries.Implementations; + +public class ProviderBillingQueries( + ILogger logger, + IProviderOrganizationRepository providerOrganizationRepository, + IProviderPlanRepository providerPlanRepository, + IProviderRepository providerRepository, + ISubscriberQueries subscriberQueries) : IProviderBillingQueries +{ + public async Task GetAssignedSeatTotalForPlanOrThrow( + Guid providerId, + PlanType planType) + { + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + logger.LogError( + "Could not find provider ({ID}) when retrieving assigned seat total", + providerId); + + throw ContactSupport(); + } + + if (provider.Type == ProviderType.Reseller) + { + logger.LogError("Assigned seats cannot be retrieved for reseller-type provider ({ID})", providerId); + + throw ContactSupport("Consolidated billing does not support reseller-type providers"); + } + + var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); + + var plan = StaticStore.GetPlan(planType); + + return providerOrganizations + .Where(providerOrganization => providerOrganization.Plan == plan.Name) + .Sum(providerOrganization => providerOrganization.Seats ?? 0); + } + + public async Task GetSubscriptionData(Guid providerId) + { + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + logger.LogError( + "Could not find provider ({ID}) when retrieving subscription data.", + providerId); + + return null; + } + + if (provider.Type == ProviderType.Reseller) + { + logger.LogError("Subscription data cannot be retrieved for reseller-type provider ({ID})", providerId); + + throw ContactSupport("Consolidated billing does not support reseller-type providers"); + } + + var subscription = await subscriberQueries.GetSubscription(provider, new SubscriptionGetOptions + { + Expand = ["customer"] + }); + + if (subscription == null) + { + return null; + } + + var providerPlans = await providerPlanRepository.GetByProviderId(providerId); + + var configuredProviderPlans = providerPlans + .Where(providerPlan => providerPlan.IsConfigured()) + .Select(ConfiguredProviderPlan.From) + .ToList(); + + return new ProviderSubscriptionData( + configuredProviderPlans, + subscription); + } +} diff --git a/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs new file mode 100644 index 0000000000..a160a87595 --- /dev/null +++ b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs @@ -0,0 +1,61 @@ +using Bit.Core.Entities; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using Stripe; + +using static Bit.Core.Billing.Utilities; + +namespace Bit.Core.Billing.Queries.Implementations; + +public class SubscriberQueries( + ILogger logger, + IStripeAdapter stripeAdapter) : ISubscriberQueries +{ + public async Task GetSubscription( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); + + return null; + } + + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + + if (subscription != null) + { + return subscription; + } + + logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); + + return null; + } + + public async Task GetSubscriptionOrThrow(ISubscriber subscriber) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); + + throw ContactSupport(); + } + + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + + if (subscription != null) + { + return subscription; + } + + logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); + + throw ContactSupport(); + } +} diff --git a/src/Core/Billing/Repositories/IProviderPlanRepository.cs b/src/Core/Billing/Repositories/IProviderPlanRepository.cs index ccfc6ee683..eccbad82bb 100644 --- a/src/Core/Billing/Repositories/IProviderPlanRepository.cs +++ b/src/Core/Billing/Repositories/IProviderPlanRepository.cs @@ -5,5 +5,5 @@ namespace Bit.Core.Billing.Repositories; public interface IProviderPlanRepository : IRepository { - Task GetByProviderId(Guid providerId); + Task> GetByProviderId(Guid providerId); } diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs index 54ace07a70..2b06f1ea6c 100644 --- a/src/Core/Billing/Utilities.cs +++ b/src/Core/Billing/Utilities.cs @@ -1,8 +1,11 @@ -using Bit.Core.Exceptions; - -namespace Bit.Core.Billing; +namespace Bit.Core.Billing; public static class Utilities { - public static GatewayException ContactSupport() => new("Something went wrong with your request. Please contact support."); + public const string BraintreeCustomerIdKey = "btCustomerId"; + + public static BillingException ContactSupport( + string internalMessage = null, + Exception innerException = null) => new("Something went wrong with your request. Please contact support.", + internalMessage, innerException); } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index e7685891ad..6edca0c505 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -114,7 +114,6 @@ public static class FeatureFlagKeys /// public const string FlexibleCollections = "flexible-collections-disabled-do-not-use"; public const string FlexibleCollectionsV1 = "flexible-collections-v-1"; // v-1 is intentional - public const string BulkCollectionAccess = "bulk-collection-access"; public const string ItemShare = "item-share"; public const string KeyRotationImprovements = "key-rotation-improvements"; public const string DuoRedirect = "duo-redirect"; @@ -131,6 +130,8 @@ public static class FeatureFlagKeys public const string PM5864DollarThreshold = "PM-5864-dollar-threshold"; public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email"; public const string ShowPaymentMethodWarningBanners = "show-payment-method-warning-banners"; + public const string EnableConsolidatedBilling = "enable-consolidated-billing"; + public const string AC1795_UpdatedSubscriptionStatusSection = "AC-1795_updated-subscription-status-section"; public static List GetAllKeys() { diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index ff3c632b5b..3e77b5d105 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -21,8 +21,9 @@ - - + + + @@ -35,9 +36,8 @@ - - + @@ -50,10 +50,10 @@ - + - + diff --git a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs index a1146cd2a0..aa1c92dc2e 100644 --- a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs +++ b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs @@ -1,5 +1,4 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.Enums; using Bit.Core.Exceptions; using Stripe; @@ -279,25 +278,6 @@ public class CompleteSubscriptionUpdate : SubscriptionUpdate }; } - private static SubscriptionItem FindSubscriptionItem(Subscription subscription, string planId) - { - if (string.IsNullOrEmpty(planId)) - { - return null; - } - - var data = subscription.Items.Data; - - var subscriptionItem = data.FirstOrDefault(item => item.Plan?.Id == planId) ?? data.FirstOrDefault(item => item.Price?.Id == planId); - - return subscriptionItem; - } - - private static string GetPasswordManagerPlanId(StaticStore.Plan plan) - => IsNonSeatBasedPlan(plan) - ? plan.PasswordManager.StripePlanId - : plan.PasswordManager.StripeSeatPlanId; - private static SubscriptionData GetSubscriptionDataFor(Organization organization) { var plan = Utilities.StaticStore.GetPlan(organization.PlanType); @@ -320,10 +300,4 @@ public class CompleteSubscriptionUpdate : SubscriptionUpdate 0 }; } - - private static bool IsNonSeatBasedPlan(StaticStore.Plan plan) - => plan.Type is - >= PlanType.FamiliesAnnually2019 and <= PlanType.EnterpriseAnnually2019 - or PlanType.FamiliesAnnually - or PlanType.TeamsStarter; } diff --git a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs b/src/Core/Models/Business/ProviderSubscriptionUpdate.cs new file mode 100644 index 0000000000..8b29bebce5 --- /dev/null +++ b/src/Core/Models/Business/ProviderSubscriptionUpdate.cs @@ -0,0 +1,61 @@ +using Bit.Core.Billing.Extensions; +using Bit.Core.Enums; +using Stripe; + +using static Bit.Core.Billing.Utilities; + +namespace Bit.Core.Models.Business; + +public class ProviderSubscriptionUpdate : SubscriptionUpdate +{ + private readonly string _planId; + private readonly int _previouslyPurchasedSeats; + private readonly int _newlyPurchasedSeats; + + protected override List PlanIds => [_planId]; + + public ProviderSubscriptionUpdate( + PlanType planType, + int previouslyPurchasedSeats, + int newlyPurchasedSeats) + { + if (!planType.SupportsConsolidatedBilling()) + { + throw ContactSupport($"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing"); + } + + _planId = GetPasswordManagerPlanId(Utilities.StaticStore.GetPlan(planType)); + _previouslyPurchasedSeats = previouslyPurchasedSeats; + _newlyPurchasedSeats = newlyPurchasedSeats; + } + + public override List RevertItemsOptions(Subscription subscription) + { + var subscriptionItem = FindSubscriptionItem(subscription, _planId); + + return + [ + new SubscriptionItemOptions + { + Id = subscriptionItem.Id, + Price = _planId, + Quantity = _previouslyPurchasedSeats + } + ]; + } + + public override List UpgradeItemsOptions(Subscription subscription) + { + var subscriptionItem = FindSubscriptionItem(subscription, _planId); + + return + [ + new SubscriptionItemOptions + { + Id = subscriptionItem.Id, + Price = _planId, + Quantity = _newlyPurchasedSeats + } + ]; + } +} diff --git a/src/Core/Models/Business/SeatSubscriptionUpdate.cs b/src/Core/Models/Business/SeatSubscriptionUpdate.cs index c5ea1a7474..db5104ddd2 100644 --- a/src/Core/Models/Business/SeatSubscriptionUpdate.cs +++ b/src/Core/Models/Business/SeatSubscriptionUpdate.cs @@ -18,7 +18,7 @@ public class SeatSubscriptionUpdate : SubscriptionUpdate public override List UpgradeItemsOptions(Subscription subscription) { - var item = SubscriptionItem(subscription, PlanIds.Single()); + var item = FindSubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions @@ -34,7 +34,7 @@ public class SeatSubscriptionUpdate : SubscriptionUpdate public override List RevertItemsOptions(Subscription subscription) { - var item = SubscriptionItem(subscription, PlanIds.Single()); + var item = FindSubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions diff --git a/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs b/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs index c93212eac8..c3e3e09992 100644 --- a/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs +++ b/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs @@ -19,7 +19,7 @@ public class ServiceAccountSubscriptionUpdate : SubscriptionUpdate public override List UpgradeItemsOptions(Subscription subscription) { - var item = SubscriptionItem(subscription, PlanIds.Single()); + var item = FindSubscriptionItem(subscription, PlanIds.Single()); _prevServiceAccounts = item?.Quantity ?? 0; return new() { @@ -35,7 +35,7 @@ public class ServiceAccountSubscriptionUpdate : SubscriptionUpdate public override List RevertItemsOptions(Subscription subscription) { - var item = SubscriptionItem(subscription, PlanIds.Single()); + var item = FindSubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions diff --git a/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs b/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs index ff6bb55011..b8201b9775 100644 --- a/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs +++ b/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs @@ -19,7 +19,7 @@ public class SmSeatSubscriptionUpdate : SubscriptionUpdate public override List UpgradeItemsOptions(Subscription subscription) { - var item = SubscriptionItem(subscription, PlanIds.Single()); + var item = FindSubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions @@ -35,7 +35,7 @@ public class SmSeatSubscriptionUpdate : SubscriptionUpdate public override List RevertItemsOptions(Subscription subscription) { - var item = SubscriptionItem(subscription, PlanIds.Single()); + var item = FindSubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions diff --git a/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs b/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs index 88af72f199..59a745297b 100644 --- a/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs +++ b/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs @@ -74,10 +74,10 @@ public class SponsorOrganizationSubscriptionUpdate : SubscriptionUpdate private string AddStripePlanId => _applySponsorship ? _sponsoredPlanStripeId : _existingPlanStripeId; private Stripe.SubscriptionItem RemoveStripeItem(Subscription subscription) => _applySponsorship ? - SubscriptionItem(subscription, _existingPlanStripeId) : - SubscriptionItem(subscription, _sponsoredPlanStripeId); + FindSubscriptionItem(subscription, _existingPlanStripeId) : + FindSubscriptionItem(subscription, _sponsoredPlanStripeId); private Stripe.SubscriptionItem AddStripeItem(Subscription subscription) => _applySponsorship ? - SubscriptionItem(subscription, _sponsoredPlanStripeId) : - SubscriptionItem(subscription, _existingPlanStripeId); + FindSubscriptionItem(subscription, _sponsoredPlanStripeId) : + FindSubscriptionItem(subscription, _existingPlanStripeId); } diff --git a/src/Core/Models/Business/StorageSubscriptionUpdate.cs b/src/Core/Models/Business/StorageSubscriptionUpdate.cs index 30ab2428e2..b0f4a83d3e 100644 --- a/src/Core/Models/Business/StorageSubscriptionUpdate.cs +++ b/src/Core/Models/Business/StorageSubscriptionUpdate.cs @@ -17,7 +17,7 @@ public class StorageSubscriptionUpdate : SubscriptionUpdate public override List UpgradeItemsOptions(Subscription subscription) { - var item = SubscriptionItem(subscription, PlanIds.Single()); + var item = FindSubscriptionItem(subscription, PlanIds.Single()); _prevStorage = item?.Quantity ?? 0; return new() { @@ -38,7 +38,7 @@ public class StorageSubscriptionUpdate : SubscriptionUpdate throw new Exception("Unknown previous value, must first call UpgradeItemsOptions"); } - var item = SubscriptionItem(subscription, PlanIds.Single()); + var item = FindSubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions diff --git a/src/Core/Models/Business/SubscriptionInfo.cs b/src/Core/Models/Business/SubscriptionInfo.cs index 23f8f95278..7bb5bddbc8 100644 --- a/src/Core/Models/Business/SubscriptionInfo.cs +++ b/src/Core/Models/Business/SubscriptionInfo.cs @@ -43,6 +43,9 @@ public class SubscriptionInfo Items = sub.Items.Data.Select(i => new BillingSubscriptionItem(i)); } CollectionMethod = sub.CollectionMethod; + GracePeriod = sub.CollectionMethod == "charge_automatically" + ? 14 + : 30; } public DateTime? TrialStartDate { get; set; } @@ -56,6 +59,9 @@ public class SubscriptionInfo public bool Cancelled { get; set; } public IEnumerable Items { get; set; } = new List(); public string CollectionMethod { get; set; } + public DateTime? SuspensionDate { get; set; } + public DateTime? UnpaidPeriodEndDate { get; set; } + public int GracePeriod { get; set; } public class BillingSubscriptionItem { diff --git a/src/Core/Models/Business/SubscriptionUpdate.cs b/src/Core/Models/Business/SubscriptionUpdate.cs index 70106a10ea..bba9d384d2 100644 --- a/src/Core/Models/Business/SubscriptionUpdate.cs +++ b/src/Core/Models/Business/SubscriptionUpdate.cs @@ -1,4 +1,5 @@ -using Stripe; +using Bit.Core.Enums; +using Stripe; namespace Bit.Core.Models.Business; @@ -15,7 +16,7 @@ public abstract class SubscriptionUpdate foreach (var upgradeItemOptions in upgradeItemsOptions) { var upgradeQuantity = upgradeItemOptions.Quantity ?? 0; - var existingQuantity = SubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0; + var existingQuantity = FindSubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0; if (upgradeQuantity != existingQuantity) { return true; @@ -24,6 +25,28 @@ public abstract class SubscriptionUpdate return false; } - protected static SubscriptionItem SubscriptionItem(Subscription subscription, string planId) => - planId == null ? null : subscription.Items?.Data?.FirstOrDefault(i => i.Plan.Id == planId); + protected static SubscriptionItem FindSubscriptionItem(Subscription subscription, string planId) + { + if (string.IsNullOrEmpty(planId)) + { + return null; + } + + var data = subscription.Items.Data; + + var subscriptionItem = data.FirstOrDefault(item => item.Plan?.Id == planId) ?? data.FirstOrDefault(item => item.Price?.Id == planId); + + return subscriptionItem; + } + + protected static string GetPasswordManagerPlanId(StaticStore.Plan plan) + => IsNonSeatBasedPlan(plan) + ? plan.PasswordManager.StripePlanId + : plan.PasswordManager.StripeSeatPlanId; + + protected static bool IsNonSeatBasedPlan(StaticStore.Plan plan) + => plan.Type is + >= PlanType.FamiliesAnnually2019 and <= PlanType.EnterpriseAnnually2019 + or PlanType.FamiliesAnnually + or PlanType.TeamsStarter; } diff --git a/src/Core/Models/Data/DictionaryEntity.cs b/src/Core/Models/Data/DictionaryEntity.cs deleted file mode 100644 index 72e6c871c7..0000000000 --- a/src/Core/Models/Data/DictionaryEntity.cs +++ /dev/null @@ -1,134 +0,0 @@ -using System.Collections; -using Microsoft.Azure.Cosmos.Table; - -namespace Bit.Core.Models.Data; - -public class DictionaryEntity : TableEntity, IDictionary -{ - private IDictionary _properties = new Dictionary(); - - public ICollection Values => _properties.Values; - - public EntityProperty this[string key] - { - get => _properties[key]; - set => _properties[key] = value; - } - - public int Count => _properties.Count; - - public bool IsReadOnly => _properties.IsReadOnly; - - public ICollection Keys => _properties.Keys; - - public override void ReadEntity(IDictionary properties, - OperationContext operationContext) - { - _properties = properties; - } - - public override IDictionary WriteEntity(OperationContext operationContext) - { - return _properties; - } - - public void Add(string key, EntityProperty value) - { - _properties.Add(key, value); - } - - public void Add(string key, bool value) - { - _properties.Add(key, new EntityProperty(value)); - } - - public void Add(string key, byte[] value) - { - _properties.Add(key, new EntityProperty(value)); - } - - public void Add(string key, DateTime? value) - { - _properties.Add(key, new EntityProperty(value)); - } - - public void Add(string key, DateTimeOffset? value) - { - _properties.Add(key, new EntityProperty(value)); - } - - public void Add(string key, double value) - { - _properties.Add(key, new EntityProperty(value)); - } - - public void Add(string key, Guid value) - { - _properties.Add(key, new EntityProperty(value)); - } - - public void Add(string key, int value) - { - _properties.Add(key, new EntityProperty(value)); - } - - public void Add(string key, long value) - { - _properties.Add(key, new EntityProperty(value)); - } - - public void Add(string key, string value) - { - _properties.Add(key, new EntityProperty(value)); - } - - public void Add(KeyValuePair item) - { - _properties.Add(item); - } - - public bool ContainsKey(string key) - { - return _properties.ContainsKey(key); - } - - public bool Remove(string key) - { - return _properties.Remove(key); - } - - public bool TryGetValue(string key, out EntityProperty value) - { - return _properties.TryGetValue(key, out value); - } - - public void Clear() - { - _properties.Clear(); - } - - public bool Contains(KeyValuePair item) - { - return _properties.Contains(item); - } - - public void CopyTo(KeyValuePair[] array, int arrayIndex) - { - _properties.CopyTo(array, arrayIndex); - } - - public bool Remove(KeyValuePair item) - { - return _properties.Remove(item); - } - - public IEnumerator> GetEnumerator() - { - return _properties.GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return _properties.GetEnumerator(); - } -} diff --git a/src/Core/Models/Data/EventTableEntity.cs b/src/Core/Models/Data/EventTableEntity.cs index df4a85acaf..69365f4127 100644 --- a/src/Core/Models/Data/EventTableEntity.cs +++ b/src/Core/Models/Data/EventTableEntity.cs @@ -1,10 +1,73 @@ -using Bit.Core.Enums; +using Azure; +using Azure.Data.Tables; +using Bit.Core.Enums; using Bit.Core.Utilities; -using Microsoft.Azure.Cosmos.Table; namespace Bit.Core.Models.Data; -public class EventTableEntity : TableEntity, IEvent +// used solely for interaction with Azure Table Storage +public class AzureEvent : ITableEntity +{ + public string PartitionKey { get; set; } + public string RowKey { get; set; } + public DateTimeOffset? Timestamp { get; set; } + public ETag ETag { get; set; } + + public DateTime Date { get; set; } + public int Type { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Guid? InstallationId { get; set; } + public Guid? ProviderId { get; set; } + public Guid? CipherId { get; set; } + public Guid? CollectionId { get; set; } + public Guid? PolicyId { get; set; } + public Guid? GroupId { get; set; } + public Guid? OrganizationUserId { get; set; } + public Guid? ProviderUserId { get; set; } + public Guid? ProviderOrganizationId { get; set; } + public int? DeviceType { get; set; } + public string IpAddress { get; set; } + public Guid? ActingUserId { get; set; } + public int? SystemUser { get; set; } + public string DomainName { get; set; } + public Guid? SecretId { get; set; } + public Guid? ServiceAccountId { get; set; } + + public EventTableEntity ToEventTableEntity() + { + return new EventTableEntity + { + PartitionKey = PartitionKey, + RowKey = RowKey, + Timestamp = Timestamp, + ETag = ETag, + + Date = Date, + Type = (EventType)Type, + UserId = UserId, + OrganizationId = OrganizationId, + InstallationId = InstallationId, + ProviderId = ProviderId, + CipherId = CipherId, + CollectionId = CollectionId, + PolicyId = PolicyId, + GroupId = GroupId, + OrganizationUserId = OrganizationUserId, + ProviderUserId = ProviderUserId, + ProviderOrganizationId = ProviderOrganizationId, + DeviceType = DeviceType.HasValue ? (DeviceType)DeviceType.Value : null, + IpAddress = IpAddress, + ActingUserId = ActingUserId, + SystemUser = SystemUser.HasValue ? (EventSystemUser)SystemUser.Value : null, + DomainName = DomainName, + SecretId = SecretId, + ServiceAccountId = ServiceAccountId + }; + } +} + +public class EventTableEntity : IEvent { public EventTableEntity() { } @@ -32,6 +95,11 @@ public class EventTableEntity : TableEntity, IEvent ServiceAccountId = e.ServiceAccountId; } + public string PartitionKey { get; set; } + public string RowKey { get; set; } + public DateTimeOffset? Timestamp { get; set; } + public ETag ETag { get; set; } + public DateTime Date { get; set; } public EventType Type { get; set; } public Guid? UserId { get; set; } @@ -53,65 +121,36 @@ public class EventTableEntity : TableEntity, IEvent public Guid? SecretId { get; set; } public Guid? ServiceAccountId { get; set; } - public override IDictionary WriteEntity(OperationContext operationContext) + public AzureEvent ToAzureEvent() { - var result = base.WriteEntity(operationContext); + return new AzureEvent + { + PartitionKey = PartitionKey, + RowKey = RowKey, + Timestamp = Timestamp, + ETag = ETag, - var typeName = nameof(Type); - if (result.ContainsKey(typeName)) - { - result[typeName] = new EntityProperty((int)Type); - } - else - { - result.Add(typeName, new EntityProperty((int)Type)); - } - - var deviceTypeName = nameof(DeviceType); - if (result.ContainsKey(deviceTypeName)) - { - result[deviceTypeName] = new EntityProperty((int?)DeviceType); - } - else - { - result.Add(deviceTypeName, new EntityProperty((int?)DeviceType)); - } - - var systemUserTypeName = nameof(SystemUser); - if (result.ContainsKey(systemUserTypeName)) - { - result[systemUserTypeName] = new EntityProperty((int?)SystemUser); - } - else - { - result.Add(systemUserTypeName, new EntityProperty((int?)SystemUser)); - } - - return result; - } - - public override void ReadEntity(IDictionary properties, - OperationContext operationContext) - { - base.ReadEntity(properties, operationContext); - - var typeName = nameof(Type); - if (properties.ContainsKey(typeName) && properties[typeName].Int32Value.HasValue) - { - Type = (EventType)properties[typeName].Int32Value.Value; - } - - var deviceTypeName = nameof(DeviceType); - if (properties.ContainsKey(deviceTypeName) && properties[deviceTypeName].Int32Value.HasValue) - { - DeviceType = (DeviceType)properties[deviceTypeName].Int32Value.Value; - } - - var systemUserTypeName = nameof(SystemUser); - if (properties.ContainsKey(systemUserTypeName) && properties[systemUserTypeName].Int32Value.HasValue) - { - SystemUser = (EventSystemUser)properties[systemUserTypeName].Int32Value.Value; - } + Date = Date, + Type = (int)Type, + UserId = UserId, + OrganizationId = OrganizationId, + InstallationId = InstallationId, + ProviderId = ProviderId, + CipherId = CipherId, + CollectionId = CollectionId, + PolicyId = PolicyId, + GroupId = GroupId, + OrganizationUserId = OrganizationUserId, + ProviderUserId = ProviderUserId, + ProviderOrganizationId = ProviderOrganizationId, + DeviceType = DeviceType.HasValue ? (int)DeviceType.Value : null, + IpAddress = IpAddress, + ActingUserId = ActingUserId, + SystemUser = SystemUser.HasValue ? (int)SystemUser.Value : null, + DomainName = DomainName, + SecretId = SecretId, + ServiceAccountId = ServiceAccountId + }; } public static List IndexEvent(EventMessage e) diff --git a/src/Core/Models/Data/InstallationDeviceEntity.cs b/src/Core/Models/Data/InstallationDeviceEntity.cs index cb7bf00873..3186efc661 100644 --- a/src/Core/Models/Data/InstallationDeviceEntity.cs +++ b/src/Core/Models/Data/InstallationDeviceEntity.cs @@ -1,8 +1,9 @@ -using Microsoft.Azure.Cosmos.Table; +using Azure; +using Azure.Data.Tables; namespace Bit.Core.Models.Data; -public class InstallationDeviceEntity : TableEntity +public class InstallationDeviceEntity : ITableEntity { public InstallationDeviceEntity() { } @@ -27,6 +28,11 @@ public class InstallationDeviceEntity : TableEntity RowKey = parts[1]; } + public string PartitionKey { get; set; } + public string RowKey { get; set; } + public DateTimeOffset? Timestamp { get; set; } + public ETag ETag { get; set; } + public static bool IsInstallationDeviceId(string deviceId) { return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_'; diff --git a/src/Core/Repositories/TableStorage/EventRepository.cs b/src/Core/Repositories/TableStorage/EventRepository.cs index 7044850033..7c5cb97dba 100644 --- a/src/Core/Repositories/TableStorage/EventRepository.cs +++ b/src/Core/Repositories/TableStorage/EventRepository.cs @@ -1,14 +1,14 @@ -using Bit.Core.Models.Data; +using Azure.Data.Tables; +using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.Core.Vault.Entities; -using Microsoft.Azure.Cosmos.Table; namespace Bit.Core.Repositories.TableStorage; public class EventRepository : IEventRepository { - private readonly CloudTable _table; + private readonly TableClient _tableClient; public EventRepository(GlobalSettings globalSettings) : this(globalSettings.Events.ConnectionString) @@ -16,9 +16,8 @@ public class EventRepository : IEventRepository public EventRepository(string storageConnectionString) { - var storageAccount = CloudStorageAccount.Parse(storageConnectionString); - var tableClient = storageAccount.CreateCloudTableClient(); - _table = tableClient.GetTableReference("event"); + var tableClient = new TableServiceClient(storageConnectionString); + _tableClient = tableClient.GetTableClient("event"); } public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, @@ -76,7 +75,7 @@ public class EventRepository : IEventRepository throw new ArgumentException(nameof(e)); } - await CreateEntityAsync(entity); + await CreateEventAsync(entity); } public async Task CreateManyAsync(IEnumerable e) @@ -99,7 +98,7 @@ public class EventRepository : IEventRepository var groupEntities = group.ToList(); if (groupEntities.Count == 1) { - await CreateEntityAsync(groupEntities.First()); + await CreateEventAsync(groupEntities.First()); continue; } @@ -107,7 +106,7 @@ public class EventRepository : IEventRepository var iterations = groupEntities.Count / 100; for (var i = 0; i <= iterations; i++) { - var batch = new TableBatchOperation(); + var batch = new List(); var batchEntities = groupEntities.Skip(i * 100).Take(100); if (!batchEntities.Any()) { @@ -116,19 +115,15 @@ public class EventRepository : IEventRepository foreach (var entity in batchEntities) { - batch.InsertOrReplace(entity); + batch.Add(new TableTransactionAction(TableTransactionActionType.Add, + entity.ToAzureEvent())); } - await _table.ExecuteBatchAsync(batch); + await _tableClient.SubmitTransactionAsync(batch); } } } - public async Task CreateEntityAsync(ITableEntity entity) - { - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); - } - public async Task> GetManyAsync(string partitionKey, string rowKey, DateTime startDate, DateTime endDate, PageOptions pageOptions) { @@ -136,60 +131,28 @@ public class EventRepository : IEventRepository var end = CoreHelpers.DateTimeToTableStorageKey(endDate); var filter = MakeFilter(partitionKey, string.Format(rowKey, start), string.Format(rowKey, end)); - var query = new TableQuery().Where(filter).Take(pageOptions.PageSize); var result = new PagedResult(); - var continuationToken = DeserializeContinuationToken(pageOptions?.ContinuationToken); + var query = _tableClient.QueryAsync(filter, pageOptions.PageSize); - var queryResults = await _table.ExecuteQuerySegmentedAsync(query, continuationToken); - result.ContinuationToken = SerializeContinuationToken(queryResults.ContinuationToken); - result.Data.AddRange(queryResults.Results); + await using (var enumerator = query.AsPages(pageOptions?.ContinuationToken, + pageOptions.PageSize).GetAsyncEnumerator()) + { + await enumerator.MoveNextAsync(); + + result.ContinuationToken = enumerator.Current.ContinuationToken; + result.Data.AddRange(enumerator.Current.Values.Select(e => e.ToEventTableEntity())); + } return result; } + private async Task CreateEventAsync(EventTableEntity entity) + { + await _tableClient.UpsertEntityAsync(entity.ToAzureEvent()); + } + private string MakeFilter(string partitionKey, string rowStart, string rowEnd) { - var rowFilter = TableQuery.CombineFilters( - TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.LessThanOrEqual, $"{rowStart}`"), - TableOperators.And, - TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.GreaterThanOrEqual, $"{rowEnd}_")); - - return TableQuery.CombineFilters( - TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, partitionKey), - TableOperators.And, - rowFilter); - } - - private string SerializeContinuationToken(TableContinuationToken token) - { - if (token == null) - { - return null; - } - - return string.Format("{0}__{1}__{2}__{3}", (int)token.TargetLocation, token.NextTableName, - token.NextPartitionKey, token.NextRowKey); - } - - private TableContinuationToken DeserializeContinuationToken(string token) - { - if (string.IsNullOrWhiteSpace(token)) - { - return null; - } - - var tokenParts = token.Split(new string[] { "__" }, StringSplitOptions.None); - if (tokenParts.Length < 4 || !Enum.TryParse(tokenParts[0], out StorageLocation tLoc)) - { - return null; - } - - return new TableContinuationToken - { - TargetLocation = tLoc, - NextTableName = string.IsNullOrWhiteSpace(tokenParts[1]) ? null : tokenParts[1], - NextPartitionKey = string.IsNullOrWhiteSpace(tokenParts[2]) ? null : tokenParts[2], - NextRowKey = string.IsNullOrWhiteSpace(tokenParts[3]) ? null : tokenParts[3] - }; + return $"PartitionKey eq '{partitionKey}' and RowKey le '{rowStart}' and RowKey ge '{rowEnd}'"; } } diff --git a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs index 32b466d1b3..2dee07dc2b 100644 --- a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs +++ b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs @@ -1,13 +1,12 @@ -using System.Net; +using Azure.Data.Tables; using Bit.Core.Models.Data; using Bit.Core.Settings; -using Microsoft.Azure.Cosmos.Table; namespace Bit.Core.Repositories.TableStorage; public class InstallationDeviceRepository : IInstallationDeviceRepository { - private readonly CloudTable _table; + private readonly TableClient _tableClient; public InstallationDeviceRepository(GlobalSettings globalSettings) : this(globalSettings.Events.ConnectionString) @@ -15,14 +14,13 @@ public class InstallationDeviceRepository : IInstallationDeviceRepository public InstallationDeviceRepository(string storageConnectionString) { - var storageAccount = CloudStorageAccount.Parse(storageConnectionString); - var tableClient = storageAccount.CreateCloudTableClient(); - _table = tableClient.GetTableReference("installationdevice"); + var tableClient = new TableServiceClient(storageConnectionString); + _tableClient = tableClient.GetTableClient("installationdevice"); } public async Task UpsertAsync(InstallationDeviceEntity entity) { - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); + await _tableClient.UpsertEntityAsync(entity); } public async Task UpsertManyAsync(IList entities) @@ -52,7 +50,7 @@ public class InstallationDeviceRepository : IInstallationDeviceRepository var iterations = groupEntities.Count / 100; for (var i = 0; i <= iterations; i++) { - var batch = new TableBatchOperation(); + var batch = new List(); var batchEntities = groupEntities.Skip(i * 100).Take(100); if (!batchEntities.Any()) { @@ -61,24 +59,16 @@ public class InstallationDeviceRepository : IInstallationDeviceRepository foreach (var entity in batchEntities) { - batch.InsertOrReplace(entity); + batch.Add(new TableTransactionAction(TableTransactionActionType.UpsertReplace, entity)); } - await _table.ExecuteBatchAsync(batch); + await _tableClient.SubmitTransactionAsync(batch); } } } public async Task DeleteAsync(InstallationDeviceEntity entity) { - try - { - entity.ETag = "*"; - await _table.ExecuteAsync(TableOperation.Delete(entity)); - } - catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) - { - throw; - } + await _tableClient.DeleteEntityAsync(entity.PartitionKey, entity.RowKey); } } diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index f8f24cfbdb..e0d2e95dc9 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; @@ -28,6 +29,12 @@ public interface IPaymentService int newlyPurchasedAdditionalStorage, DateTime? prorationDate = null); Task AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null); + Task AdjustSeats( + Provider provider, + Plan plan, + int currentlySubscribedSeats, + int newlySubscribedSeats, + DateTime? prorationDate = null); Task AdjustSmSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null); Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId, DateTime? prorationDate = null); diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs index 073d5cdacd..908dc2c0d8 100644 --- a/src/Core/Services/IStripeAdapter.cs +++ b/src/Core/Services/IStripeAdapter.cs @@ -1,4 +1,5 @@ using Bit.Core.Models.BitStripe; +using Stripe; namespace Bit.Core.Services; @@ -16,6 +17,7 @@ public interface IStripeAdapter Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); Task> InvoiceListAsync(StripeInvoiceListOptions options); + Task> InvoiceSearchAsync(InvoiceSearchOptions options); Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options); diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs index ef8d13aea8..a7109252d4 100644 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Services/Implementations/StripeAdapter.cs @@ -1,4 +1,5 @@ using Bit.Core.Models.BitStripe; +using Stripe; namespace Bit.Core.Services; @@ -103,6 +104,9 @@ public class StripeAdapter : IStripeAdapter return invoices; } + public async Task> InvoiceSearchAsync(InvoiceSearchOptions options) + => (await _invoiceService.SearchAsync(options)).Data; + public Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options) { return _invoiceService.UpdateAsync(id, options); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 19437a1ee2..234543a8f6 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; using Bit.Core.Entities; using Bit.Core.Enums; @@ -757,14 +758,14 @@ public class StripePaymentService : IPaymentService }).ToList(); } - private async Task FinalizeSubscriptionChangeAsync(IStorableSubscriber storableSubscriber, + private async Task FinalizeSubscriptionChangeAsync(ISubscriber subscriber, SubscriptionUpdate subscriptionUpdate, DateTime? prorationDate, bool invoiceNow = false) { // remember, when in doubt, throw var subGetOptions = new SubscriptionGetOptions(); // subGetOptions.AddExpand("customer"); subGetOptions.AddExpand("customer.tax"); - var sub = await _stripeAdapter.SubscriptionGetAsync(storableSubscriber.GatewaySubscriptionId, subGetOptions); + var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions); if (sub == null) { throw new GatewayException("Subscription not found."); @@ -776,6 +777,7 @@ public class StripePaymentService : IPaymentService var chargeNow = collectionMethod == "charge_automatically"; var updatedItemOptions = subscriptionUpdate.UpgradeItemsOptions(sub); var isPm5864DollarThresholdEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5864DollarThreshold); + var isAnnualPlan = sub?.Items?.Data.FirstOrDefault()?.Plan?.Interval == "year"; var subUpdateOptions = new SubscriptionUpdateOptions { @@ -787,25 +789,10 @@ public class StripePaymentService : IPaymentService CollectionMethod = "send_invoice", ProrationDate = prorationDate, }; - var immediatelyInvoice = false; - if (!invoiceNow && isPm5864DollarThresholdEnabled && sub.Status.Trim() != "trialing") + if (!invoiceNow && isAnnualPlan && isPm5864DollarThresholdEnabled && sub.Status.Trim() != "trialing") { - var upcomingInvoiceWithChanges = await _stripeAdapter.InvoiceUpcomingAsync(new UpcomingInvoiceOptions - { - Customer = storableSubscriber.GatewayCustomerId, - Subscription = storableSubscriber.GatewaySubscriptionId, - SubscriptionItems = ToInvoiceSubscriptionItemOptions(updatedItemOptions), - SubscriptionProrationBehavior = Constants.CreateProrations, - SubscriptionProrationDate = prorationDate, - SubscriptionBillingCycleAnchor = SubscriptionBillingCycleAnchor.Now - }); - - var isAnnualPlan = sub?.Items?.Data.FirstOrDefault()?.Plan?.Interval == "year"; - immediatelyInvoice = isAnnualPlan && upcomingInvoiceWithChanges.AmountRemaining >= 50000; - - subUpdateOptions.BillingCycleAnchor = immediatelyInvoice - ? SubscriptionBillingCycleAnchor.Now - : SubscriptionBillingCycleAnchor.Unchanged; + subUpdateOptions.PendingInvoiceItemInterval = + new SubscriptionPendingInvoiceItemIntervalOptions { Interval = "month" }; } var pm5766AutomaticTaxIsEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax); @@ -858,21 +845,16 @@ public class StripePaymentService : IPaymentService { try { - if (!isPm5864DollarThresholdEnabled || immediatelyInvoice || invoiceNow) + if (chargeNow) { - if (chargeNow) - { - paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync(storableSubscriber, invoice); - } - else - { - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, new InvoiceFinalizeOptions - { - AutoAdvance = false, - }); - await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); - paymentIntentClientSecret = null; - } + paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync(subscriber, invoice); + } + else + { + invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, + new InvoiceFinalizeOptions { AutoAdvance = false, }); + await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); + paymentIntentClientSecret = null; } } catch @@ -943,6 +925,17 @@ public class StripePaymentService : IPaymentService return FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate); } + public Task AdjustSeats( + Provider provider, + StaticStore.Plan plan, + int currentlySubscribedSeats, + int newlySubscribedSeats, + DateTime? prorationDate = null) + => FinalizeSubscriptionChangeAsync( + provider, + new ProviderSubscriptionUpdate(plan.Type, currentlySubscribedSeats, newlySubscribedSeats), + prorationDate); + public Task AdjustSmSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats, DateTime? prorationDate = null) { return FinalizeSubscriptionChangeAsync(organization, new SmSeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate); @@ -1610,10 +1603,25 @@ public class StripePaymentService : IPaymentService return subscriptionInfo; } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, new SubscriptionGetOptions + { + Expand = ["test_clock"] + }); + if (sub != null) { subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub); + + if (_featureService.IsEnabled(FeatureFlagKeys.AC1795_UpdatedSubscriptionStatusSection)) + { + var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(sub); + + if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue) + { + subscriptionInfo.Subscription.SuspensionDate = suspensionDate; + subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate; + } + } } if (sub is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) @@ -1930,4 +1938,45 @@ public class StripePaymentService : IPaymentService ? subscriberName : subscriberName[..30]; } + + private async Task<(DateTime?, DateTime?)> GetSuspensionDateAsync(Subscription subscription) + { + if (subscription.Status is not "past_due" && subscription.Status is not "unpaid") + { + return (null, null); + } + + var openInvoices = await _stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + { + Query = $"subscription:'{subscription.Id}' status:'open'" + }); + + if (openInvoices.Count == 0) + { + return (null, null); + } + + var currentDate = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow; + + switch (subscription.CollectionMethod) + { + case "charge_automatically": + { + var firstOverdueInvoice = openInvoices + .Where(invoice => invoice.PeriodEnd < currentDate && invoice.Attempted) + .MinBy(invoice => invoice.Created); + + return (firstOverdueInvoice?.Created.AddDays(14), firstOverdueInvoice?.PeriodEnd); + } + case "send_invoice": + { + var firstOverdueInvoice = openInvoices + .Where(invoice => invoice.DueDate < currentDate) + .MinBy(invoice => invoice.Created); + + return (firstOverdueInvoice?.DueDate?.AddDays(30), firstOverdueInvoice?.PeriodEnd); + } + default: return (null, null); + } + } } diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index 84037a0a1c..50b4efe6fb 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -221,6 +221,8 @@ public class GlobalSettings : IGlobalSettings private string _connectionString; private string _readOnlyConnectionString; private string _jobSchedulerConnectionString; + public bool SkipDatabasePreparation { get; set; } + public bool DisableDatabaseMaintenanceJobs { get; set; } public string ConnectionString { diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs index 5d0becf7b4..af658a409a 100644 --- a/src/Core/Utilities/CoreHelpers.cs +++ b/src/Core/Utilities/CoreHelpers.cs @@ -32,6 +32,10 @@ public static class CoreHelpers private static readonly Random _random = new Random(); private static readonly string RealConnectingIp = "X-Connecting-IP"; private static readonly Regex _whiteSpaceRegex = new Regex(@"\s+"); + private static readonly JsonSerializerOptions _jsonSerializerOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }; /// /// Generate sequential Guid for Sql Server. @@ -778,22 +782,12 @@ public static class CoreHelpers return new T(); } - var options = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }; - - return System.Text.Json.JsonSerializer.Deserialize(jsonData, options); + return System.Text.Json.JsonSerializer.Deserialize(jsonData, _jsonSerializerOptions); } public static string ClassToJsonData(T data) { - var options = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }; - - return System.Text.Json.JsonSerializer.Serialize(data, options); + return System.Text.Json.JsonSerializer.Serialize(data, _jsonSerializerOptions); } public static ICollection AddIfNotExists(this ICollection list, T item) diff --git a/src/Core/Utilities/StaticStore.cs b/src/Core/Utilities/StaticStore.cs index dcf63df138..007f3374e0 100644 --- a/src/Core/Utilities/StaticStore.cs +++ b/src/Core/Utilities/StaticStore.cs @@ -147,7 +147,6 @@ public static class StaticStore public static Plan GetPlan(PlanType planType) => Plans.SingleOrDefault(p => p.Type == planType); - public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) => SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType); diff --git a/src/Events/appsettings.json b/src/Events/appsettings.json index 101911bb0d..e72b978f2f 100644 --- a/src/Events/appsettings.json +++ b/src/Events/appsettings.json @@ -14,10 +14,6 @@ "events": { "connectionString": "SECRET" }, - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - }, "sentry": { "dsn": "SECRET" }, diff --git a/src/EventsProcessor/AzureQueueHostedService.cs b/src/EventsProcessor/AzureQueueHostedService.cs index 03c0034539..b1b309b50f 100644 --- a/src/EventsProcessor/AzureQueueHostedService.cs +++ b/src/EventsProcessor/AzureQueueHostedService.cs @@ -30,6 +30,7 @@ public class AzureQueueHostedService : IHostedService, IDisposable _logger.LogInformation(Constants.BypassFiltersEventId, "Starting service."); _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; } @@ -39,8 +40,10 @@ public class AzureQueueHostedService : IHostedService, IDisposable { return; } + _logger.LogWarning("Stopping service."); - _cts.Cancel(); + + await _cts.CancelAsync(); await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); cancellationToken.ThrowIfCancellationRequested(); } @@ -64,13 +67,15 @@ public class AzureQueueHostedService : IHostedService, IDisposable { try { - var messages = await _queueClient.ReceiveMessagesAsync(32); + var messages = await _queueClient.ReceiveMessagesAsync(32, + cancellationToken: cancellationToken); if (messages.Value?.Any() ?? false) { foreach (var message in messages.Value) { await ProcessQueueMessageAsync(message.DecodeMessageText(), cancellationToken); - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, + cancellationToken); } } else @@ -78,14 +83,15 @@ public class AzureQueueHostedService : IHostedService, IDisposable await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); } } - catch (Exception e) + catch (Exception ex) { - _logger.LogError(e, "Exception occurred: " + e.Message); + _logger.LogError(ex, "Error occurred processing message block."); + await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); } } - _logger.LogWarning("Done processing."); + _logger.LogWarning("Done processing messages."); } public async Task ProcessQueueMessageAsync(string message, CancellationToken cancellationToken) @@ -98,14 +104,14 @@ public class AzureQueueHostedService : IHostedService, IDisposable try { _logger.LogInformation("Processing message."); - var events = new List(); + var events = new List(); using var jsonDocument = JsonDocument.Parse(message); var root = jsonDocument.RootElement; if (root.ValueKind == JsonValueKind.Array) { var indexedEntities = root.Deserialize>() - .SelectMany(e => EventTableEntity.IndexEvent(e)); + .SelectMany(EventTableEntity.IndexEvent); events.AddRange(indexedEntities); } else if (root.ValueKind == JsonValueKind.Object) @@ -114,12 +120,15 @@ public class AzureQueueHostedService : IHostedService, IDisposable events.AddRange(EventTableEntity.IndexEvent(eventMessage)); } + cancellationToken.ThrowIfCancellationRequested(); + await _eventWriteService.CreateManyAsync(events); + _logger.LogInformation("Processed message."); } - catch (JsonException) + catch (JsonException ex) { - _logger.LogError("JsonReaderException: Unable to parse message."); + _logger.LogError(ex, "Unable to parse message."); } } } diff --git a/src/EventsProcessor/appsettings.json b/src/EventsProcessor/appsettings.json index af0ca259fa..c2c77bcb0d 100644 --- a/src/EventsProcessor/appsettings.json +++ b/src/EventsProcessor/appsettings.json @@ -2,10 +2,6 @@ "azureStorageConnectionString": "SECRET", "globalSettings": { "selfHosted": false, - "projectName": "Events Processor", - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - } + "projectName": "Events Processor" } } diff --git a/src/Icons/appsettings.json b/src/Icons/appsettings.json index 65267ef4e9..6b4e2992e0 100644 --- a/src/Icons/appsettings.json +++ b/src/Icons/appsettings.json @@ -1,10 +1,6 @@ { "globalSettings": { - "projectName": "Icons", - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - } + "projectName": "Icons" }, "iconsSettings": { "cacheEnabled": true, diff --git a/src/Identity/appsettings.json b/src/Identity/appsettings.json index e3626b4e16..16c3efe46b 100644 --- a/src/Identity/appsettings.json +++ b/src/Identity/appsettings.json @@ -27,10 +27,6 @@ "events": { "connectionString": "SECRET" }, - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - }, "sentry": { "dsn": "SECRET" }, diff --git a/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs b/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs index 761545a255..f8448f4198 100644 --- a/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs +++ b/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs @@ -14,7 +14,7 @@ public class ProviderPlanRepository( globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString), IProviderPlanRepository { - public async Task GetByProviderId(Guid providerId) + public async Task> GetByProviderId(Guid providerId) { var sqlConnection = new SqlConnection(ConnectionString); @@ -23,6 +23,6 @@ public class ProviderPlanRepository( new { ProviderId = providerId }, commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); + return results.ToArray(); } } diff --git a/src/Infrastructure.Dapper/Infrastructure.Dapper.csproj b/src/Infrastructure.Dapper/Infrastructure.Dapper.csproj index 6c7ad57d19..046009ef73 100644 --- a/src/Infrastructure.Dapper/Infrastructure.Dapper.csproj +++ b/src/Infrastructure.Dapper/Infrastructure.Dapper.csproj @@ -5,7 +5,7 @@ - + diff --git a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs index 2f9a707b27..386f7115d7 100644 --- a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs +++ b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs @@ -16,14 +16,17 @@ public class ProviderPlanRepository( mapper, context => context.ProviderPlans), IProviderPlanRepository { - public async Task GetByProviderId(Guid providerId) + public async Task> GetByProviderId(Guid providerId) { using var serviceScope = ServiceScopeFactory.CreateScope(); + var databaseContext = GetDatabaseContext(serviceScope); + var query = from providerPlan in databaseContext.ProviderPlans where providerPlan.ProviderId == providerId select providerPlan; - return await query.FirstOrDefaultAsync(); + + return await query.ToArrayAsync(); } } diff --git a/src/Notifications/appsettings.json b/src/Notifications/appsettings.json index 82355a0771..020d98cbd6 100644 --- a/src/Notifications/appsettings.json +++ b/src/Notifications/appsettings.json @@ -18,10 +18,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "documentDb": { - "uri": "SECRET", - "key": "SECRET" - }, "sentry": { "dsn": "SECRET" }, diff --git a/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs b/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs index 90a2335c22..f669e89eb0 100644 --- a/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs +++ b/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs @@ -64,4 +64,13 @@ public class ApiApplicationFactory : WebApplicationFactoryBase base.Dispose(disposing); SqliteConnection.Dispose(); } + + /// + /// Helper for logging in via client secret. + /// Currently used for Secrets Manager service accounts + /// + public async Task LoginWithClientSecretAsync(Guid clientId, string clientSecret) + { + return await _identityApplicationFactory.TokenFromAccessTokenAsync(clientId, clientSecret); + } } diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/AccessPoliciesControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/AccessPoliciesControllerTests.cs index b8eb4a7700..e1cce68704 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/AccessPoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/AccessPoliciesControllerTests.cs @@ -1,7 +1,7 @@ using System.Net; -using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.SecretsManager.Enums; +using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.Models.Response; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.SecretsManager.Models.Response; @@ -28,6 +28,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); _projectRepository = _factory.GetService(); _groupRepository = _factory.GetService(); + _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -54,12 +56,6 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(serviceAccountId, result!.ServiceAccountAccessPolicies.First().ServiceAccountId); + Assert.Equal(serviceAccountId, result.ServiceAccountAccessPolicies.First().ServiceAccountId); Assert.True(result.ServiceAccountAccessPolicies.First().Read); Assert.True(result.ServiceAccountAccessPolicies.First().Write); AssertHelper.AssertRecent(result.ServiceAccountAccessPolicies.First().RevisionDate); @@ -168,7 +164,7 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -249,13 +245,13 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(expectedRead, result!.Read); + Assert.Equal(expectedRead, result.Read); Assert.Equal(expectedWrite, result.Write); AssertHelper.AssertRecent(result.RevisionDate); var updatedAccessPolicy = await _accessPolicyRepository.GetByIdAsync(result.Id); Assert.NotNull(updatedAccessPolicy); - Assert.Equal(expectedRead, updatedAccessPolicy!.Read); + Assert.Equal(expectedRead, updatedAccessPolicy.Read); Assert.Equal(expectedWrite, updatedAccessPolicy.Write); AssertHelper.AssertRecent(updatedAccessPolicy.RevisionDate); } @@ -271,7 +267,7 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -327,7 +323,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Empty(result!.UserAccessPolicies); + Assert.Empty(result.UserAccessPolicies); Assert.Empty(result.GroupAccessPolicies); Assert.Empty(result.ServiceAccountAccessPolicies); } @@ -357,7 +353,7 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -409,7 +405,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result?.ServiceAccountAccessPolicies); - Assert.Single(result!.ServiceAccountAccessPolicies); + Assert.Single(result.ServiceAccountAccessPolicies); } [Theory] @@ -423,7 +419,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.NotEmpty(result!.Data); + Assert.NotEmpty(result.Data); } [Theory] @@ -467,7 +463,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.Empty(result!.Data); + Assert.Empty(result.Data); } [Theory] @@ -507,7 +503,7 @@ public class AccessPoliciesControllerTests : IClassFixture @@ -541,7 +537,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.NotEmpty(result!.Data); + Assert.NotEmpty(result.Data); Assert.Equal(serviceAccount.Id, result.Data.First(x => x.Id == serviceAccount.Id).Id); } @@ -556,7 +552,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.Empty(result!.Data); + Assert.Empty(result.Data); } [Theory] @@ -592,7 +588,7 @@ public class AccessPoliciesControllerTests : IClassFixture @@ -623,7 +619,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.NotEmpty(result!.Data); + Assert.NotEmpty(result.Data); Assert.Equal(project.Id, result.Data.First(x => x.Id == project.Id).Id); } @@ -638,7 +634,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.NotEmpty(result!.Data); + Assert.NotEmpty(result.Data); Assert.Equal(projectId, result.Data.First().GrantedProjectId); var createdAccessPolicy = await _accessPolicyRepository.GetByIdAsync(result.Data.First().Id); Assert.NotNull(createdAccessPolicy); - Assert.Equal(result.Data.First().Read, createdAccessPolicy!.Read); + Assert.Equal(result.Data.First().Read, createdAccessPolicy.Read); Assert.Equal(result.Data.First().Write, createdAccessPolicy.Write); Assert.Equal(result.Data.First().Id, createdAccessPolicy.Id); AssertHelper.AssertRecent(createdAccessPolicy.CreationDate); @@ -747,7 +743,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.Empty(result!.Data); + Assert.Empty(result.Data); } [Fact] @@ -782,7 +778,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.Empty(result!.Data); + Assert.Empty(result.Data); } [Theory] @@ -801,13 +797,13 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -825,7 +821,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.NotEmpty(result!.Data); + Assert.NotEmpty(result.Data); Assert.Equal(initData.ServiceAccountId, result.Data.First().ServiceAccountId); Assert.NotNull(result.Data.First().ServiceAccountName); Assert.NotNull(result.Data.First().GrantedProjectName); @@ -842,7 +838,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Empty(result!.UserAccessPolicies); + Assert.Empty(result.UserAccessPolicies); Assert.Empty(result.GroupAccessPolicies); } @@ -881,7 +877,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result?.UserAccessPolicies); - Assert.Single(result!.UserAccessPolicies); + Assert.Single(result.UserAccessPolicies); } [Theory] @@ -924,7 +920,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Empty(result!.UserAccessPolicies); + Assert.Empty(result.UserAccessPolicies); Assert.Empty(result.GroupAccessPolicies); } @@ -1061,7 +1057,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result?.UserAccessPolicies); - Assert.Single(result!.UserAccessPolicies); + Assert.Single(result.UserAccessPolicies); } [Theory] @@ -1100,7 +1096,7 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -1361,35 +1357,6 @@ public class AccessPoliciesControllerTests : IClassFixture SetupUserServiceAccountAccessPolicyRequestAsync( - PermissionType permissionType, Guid userId, Guid serviceAccountId) - { - if (permissionType == PermissionType.RunAsUserWithPermission) - { - var (email, newOrgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); - var accessPolicies = new List - { - new UserServiceAccountAccessPolicy - { - GrantedServiceAccountId = serviceAccountId, - OrganizationUserId = newOrgUser.Id, - Read = true, - Write = true, - }, - }; - await _accessPolicyRepository.CreateManyAsync(accessPolicies); - } - - return new AccessPoliciesCreateRequest - { - UserAccessPolicyRequests = new List - { - new() { GranteeId = userId, Read = true, Write = true }, - }, - }; - } - private class RequestSetupData { public Guid ProjectId { get; set; } diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/ProjectsControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/ProjectsControllerTests.cs index 523998ee28..95ddfd678e 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/ProjectsControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/ProjectsControllerTests.cs @@ -1,7 +1,7 @@ using System.Net; -using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.SecretsManager.Enums; +using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.Models.Response; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.SecretsManager.Models.Response; @@ -10,7 +10,6 @@ using Bit.Core.Enums; using Bit.Core.SecretsManager.Entities; using Bit.Core.SecretsManager.Repositories; using Bit.Test.Common.Helpers; -using Pipelines.Sockets.Unofficial.Arenas; using Xunit; namespace Bit.Api.IntegrationTest.SecretsManager.Controllers; @@ -24,6 +23,7 @@ public class ProjectsControllerTests : IClassFixture, IAs private readonly ApiApplicationFactory _factory; private readonly IProjectRepository _projectRepository; private readonly IAccessPolicyRepository _accessPolicyRepository; + private readonly LoginHelper _loginHelper; private string _email = null!; private SecretsManagerOrganizationHelper _organizationHelper = null!; @@ -34,6 +34,7 @@ public class ProjectsControllerTests : IClassFixture, IAs _client = _factory.CreateClient(); _projectRepository = _factory.GetService(); _accessPolicyRepository = _factory.GetService(); + _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -49,12 +50,6 @@ public class ProjectsControllerTests : IClassFixture, IAs return Task.CompletedTask; } - private async Task LoginAsync(string email) - { - var tokens = await _factory.LoginAsync(email); - _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); - } - [Theory] [InlineData(false, false, false)] [InlineData(false, false, true)] @@ -66,7 +61,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task ListByOrganization_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var response = await _client.GetAsync($"/organizations/{org.Id}/projects"); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -77,7 +72,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); await CreateProjectsAsync(org.Id); @@ -86,7 +81,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var result = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(result); - Assert.Empty(result!.Data); + Assert.Empty(result.Data); } [Theory] @@ -101,7 +96,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var result = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(result); - Assert.NotEmpty(result!.Data); + Assert.NotEmpty(result.Data); Assert.Equal(projectIds.Count, result.Data.Count()); } @@ -116,7 +111,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Create_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var request = new ProjectCreateRequestModel { Name = _mockEncryptedString }; @@ -129,7 +124,7 @@ public class ProjectsControllerTests : IClassFixture, IAs [InlineData(PermissionType.RunAsUserWithPermission)] public async Task Create_AtMaxProjects_BadRequest(PermissionType permissionType) { - var (_, organization) = await SetupProjectsWithAccessAsync(permissionType, 3); + var (_, organization) = await SetupProjectsWithAccessAsync(permissionType); var request = new ProjectCreateRequestModel { Name = _mockEncryptedString }; var response = await _client.PostAsJsonAsync($"/organizations/{organization.Id}/projects", request); @@ -143,14 +138,14 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Create_Success(PermissionType permissionType) { var (org, adminOrgUser) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var orgUserId = adminOrgUser.Id; var currentUserId = adminOrgUser.UserId!.Value; if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); orgUserId = orgUser.Id; currentUserId = orgUser.UserId!.Value; } @@ -162,7 +157,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var result = await response.Content.ReadFromJsonAsync(); Assert.NotNull(result); - Assert.Equal(request.Name, result!.Name); + Assert.Equal(request.Name, result.Name); AssertHelper.AssertRecent(result.RevisionDate); AssertHelper.AssertRecent(result.CreationDate); @@ -196,7 +191,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Update_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var initialProject = await _projectRepository.CreateAsync(new Project { @@ -244,7 +239,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Update_NonExistingProject_NotFound() { await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var request = new ProjectUpdateRequestModel { @@ -262,7 +257,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var project = await _projectRepository.CreateAsync(new Project { @@ -292,7 +287,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Get_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project { @@ -313,7 +308,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var createdProject = await _projectRepository.CreateAsync(new Project { @@ -330,7 +325,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var createdProject = await _projectRepository.CreateAsync(new Project { @@ -338,7 +333,7 @@ public class ProjectsControllerTests : IClassFixture, IAs Name = _mockEncryptedString, }); - var deleteResponse = await _client.PostAsync("/projects/delete", JsonContent.Create(createdProject.Id)); + await _client.PostAsync("/projects/delete", JsonContent.Create(createdProject.Id)); var response = await _client.GetAsync($"/projects/{createdProject.Id}"); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -372,7 +367,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Delete_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var projectIds = await CreateProjectsAsync(org.Id); @@ -385,7 +380,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var projectIds = await CreateProjectsAsync(org.Id); @@ -394,7 +389,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var results = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(results); Assert.Equal(projectIds.OrderBy(x => x), - results!.Data.Select(x => x.Id).OrderBy(x => x)); + results.Data.Select(x => x.Id).OrderBy(x => x)); Assert.All(results.Data, item => Assert.Equal("access denied", item.Error)); } @@ -411,7 +406,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var results = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(results); Assert.Equal(projectIds.OrderBy(x => x), - results!.Data.Select(x => x.Id).OrderBy(x => x)); + results.Data.Select(x => x.Id).OrderBy(x => x)); Assert.DoesNotContain(results.Data, x => x.Error != null); var projects = await _projectRepository.GetManyWithSecretsByIds(projectIds); @@ -438,7 +433,7 @@ public class ProjectsControllerTests : IClassFixture, IAs int projectsToCreate = 3) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var projectIds = await CreateProjectsAsync(org.Id, projectsToCreate); if (permissionType == PermissionType.RunAsAdmin) @@ -447,7 +442,7 @@ public class ProjectsControllerTests : IClassFixture, IAs } var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var accessPolicies = projectIds.Select(projectId => new UserProjectAccessPolicy { @@ -467,7 +462,7 @@ public class ProjectsControllerTests : IClassFixture, IAs private async Task SetupProjectWithAccessAsync(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var initialProject = await _projectRepository.CreateAsync(new Project { @@ -481,7 +476,7 @@ public class ProjectsControllerTests : IClassFixture, IAs } var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var accessPolicies = new List { diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsControllerTests.cs index 4932ad9b9b..0ff7396eda 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsControllerTests.cs @@ -1,7 +1,7 @@ using System.Net; -using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.SecretsManager.Enums; +using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.Models.Response; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.SecretsManager.Models.Response; @@ -23,6 +23,7 @@ public class SecretsControllerTests : IClassFixture, IAsy private readonly ISecretRepository _secretRepository; private readonly IProjectRepository _projectRepository; private readonly IAccessPolicyRepository _accessPolicyRepository; + private readonly LoginHelper _loginHelper; private string _email = null!; private SecretsManagerOrganizationHelper _organizationHelper = null!; @@ -34,6 +35,7 @@ public class SecretsControllerTests : IClassFixture, IAsy _secretRepository = _factory.GetService(); _projectRepository = _factory.GetService(); _accessPolicyRepository = _factory.GetService(); + _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -49,12 +51,6 @@ public class SecretsControllerTests : IClassFixture, IAsy return Task.CompletedTask; } - private async Task LoginAsync(string email) - { - var tokens = await _factory.LoginAsync(email); - _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); - } - [Theory] [InlineData(false, false, false)] [InlineData(false, false, true)] @@ -66,7 +62,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task ListByOrganization_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var response = await _client.GetAsync($"/organizations/{org.Id}/secrets"); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -77,8 +73,8 @@ public class SecretsControllerTests : IClassFixture, IAsy [InlineData(PermissionType.RunAsUserWithPermission)] public async Task ListByOrganization_Success(PermissionType permissionType) { - var (org, orgUserOwner) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project { @@ -90,7 +86,7 @@ public class SecretsControllerTests : IClassFixture, IAsy if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var accessPolicies = new List { @@ -122,7 +118,7 @@ public class SecretsControllerTests : IClassFixture, IAsy var result = await response.Content.ReadFromJsonAsync(); Assert.NotNull(result); - Assert.NotEmpty(result!.Secrets); + Assert.NotEmpty(result.Secrets); Assert.Equal(secretIds.Count, result.Secrets.Count()); } @@ -137,7 +133,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Create_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var request = new SecretCreateRequestModel { @@ -154,7 +150,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithoutProject_RunAsAdmin_Success() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var request = new SecretCreateRequestModel { @@ -168,7 +164,7 @@ public class SecretsControllerTests : IClassFixture, IAsy var result = await response.Content.ReadFromJsonAsync(); Assert.NotNull(result); - Assert.Equal(request.Key, result!.Key); + Assert.Equal(request.Key, result.Key); Assert.Equal(request.Value, result.Value); Assert.Equal(request.Note, result.Note); AssertHelper.AssertRecent(result.RevisionDate); @@ -188,7 +184,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithDifferentProjectOrgId_RunAsAdmin_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var anotherOrg = await _organizationHelper.CreateSmOrganizationAsync(); var project = @@ -210,7 +206,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithMultipleProjects_RunAsAdmin_BadRequest() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var projectA = await _projectRepository.CreateAsync(new Project { OrganizationId = org.Id, Name = "123A" }); var projectB = await _projectRepository.CreateAsync(new Project { OrganizationId = org.Id, Name = "123B" }); @@ -231,8 +227,8 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithoutProject_RunAsUser_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); + await _loginHelper.LoginAsync(email); var request = new SecretCreateRequestModel { @@ -251,9 +247,9 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithProject_Success(PermissionType permissionType) { var (org, orgAdminUser) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); - AccessClientType accessType = AccessClientType.NoAccessCheck; + var accessType = AccessClientType.NoAccessCheck; var project = await _projectRepository.CreateAsync(new Project() { @@ -267,12 +263,12 @@ public class SecretsControllerTests : IClassFixture, IAsy if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); accessType = AccessClientType.User; var accessPolicies = new List { - new Core.SecretsManager.Entities.UserProjectAccessPolicy + new UserProjectAccessPolicy { GrantedProjectId = project.Id, OrganizationUserId = orgUser.Id , Read = true, Write = true, }, @@ -296,7 +292,7 @@ public class SecretsControllerTests : IClassFixture, IAsy var secret = result.Secret; Assert.NotNull(secretResult); - Assert.Equal(secret.Id, secretResult!.Id); + Assert.Equal(secret.Id, secretResult.Id); Assert.Equal(secret.OrganizationId, secretResult.OrganizationId); Assert.Equal(secret.Key, secretResult.Key); Assert.Equal(secret.Value, secretResult.Value); @@ -316,7 +312,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Get_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -336,7 +332,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Get_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project() { @@ -348,7 +344,7 @@ public class SecretsControllerTests : IClassFixture, IAsy if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var accessPolicies = new List { @@ -361,8 +357,8 @@ public class SecretsControllerTests : IClassFixture, IAsy } else { - var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.Admin, true); - await LoginAsync(email); + var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.Admin, true); + await _loginHelper.LoginAsync(email); } var secret = await _secretRepository.CreateAsync(new Secret @@ -395,7 +391,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByProject_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project { @@ -411,8 +407,8 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByProject_UserWithNoPermission_EmptyList() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); + await _loginHelper.LoginAsync(email); var project = await _projectRepository.CreateAsync(new Project() { @@ -421,7 +417,7 @@ public class SecretsControllerTests : IClassFixture, IAsy Name = _mockEncryptedString }); - var secret = await _secretRepository.CreateAsync(new Secret + await _secretRepository.CreateAsync(new Secret { OrganizationId = org.Id, Key = _mockEncryptedString, @@ -434,8 +430,8 @@ public class SecretsControllerTests : IClassFixture, IAsy response.EnsureSuccessStatusCode(); var result = await response.Content.ReadFromJsonAsync(); Assert.NotNull(result); - Assert.Empty(result!.Secrets); - Assert.Empty(result!.Projects); + Assert.Empty(result.Secrets); + Assert.Empty(result.Projects); } [Theory] @@ -444,7 +440,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByProject_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project() { @@ -456,7 +452,7 @@ public class SecretsControllerTests : IClassFixture, IAsy if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var accessPolicies = new List { @@ -501,7 +497,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Update_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -525,32 +521,18 @@ public class SecretsControllerTests : IClassFixture, IAsy [Theory] [InlineData(PermissionType.RunAsAdmin)] [InlineData(PermissionType.RunAsUserWithPermission)] + [InlineData(PermissionType.RunAsServiceAccountWithPermission)] public async Task Update_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); - var project = await _projectRepository.CreateAsync(new Project() { - Id = new Guid(), + Id = Guid.NewGuid(), OrganizationId = org.Id, Name = _mockEncryptedString }); - if (permissionType == PermissionType.RunAsUserWithPermission) - { - var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); - - var accessPolicies = new List - { - new UserProjectAccessPolicy - { - GrantedProjectId = project.Id, OrganizationUserId = orgUser.Id, Read = true, Write = true, - }, - }; - await _accessPolicyRepository.CreateManyAsync(accessPolicies); - } + await SetupProjectPermissionAndLoginAsync(permissionType, project); var secret = await _secretRepository.CreateAsync(new Secret { @@ -558,7 +540,7 @@ public class SecretsControllerTests : IClassFixture, IAsy Key = _mockEncryptedString, Value = _mockEncryptedString, Note = _mockEncryptedString, - Projects = permissionType == PermissionType.RunAsUserWithPermission ? new List() { project } : null + Projects = permissionType != PermissionType.RunAsAdmin ? new List() { project } : null }); var request = new SecretUpdateRequestModel() @@ -566,7 +548,7 @@ public class SecretsControllerTests : IClassFixture, IAsy Key = _mockEncryptedString, Value = "2.3Uk+WNBIoU5xzmVFNcoWzz==|1MsPIYuRfdOHfu/0uY6H2Q==|/98xy4wb6pHP1VTZ9JcNCYgQjEUMFPlqJgCwRk1YXKg=", Note = _mockEncryptedString, - ProjectIds = permissionType == PermissionType.RunAsUserWithPermission ? new Guid[] { project.Id } : null + ProjectIds = permissionType != PermissionType.RunAsAdmin ? new Guid[] { project.Id } : null }; var response = await _client.PutAsJsonAsync($"/secrets/{secret.Id}", request); @@ -595,7 +577,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task UpdateWithDifferentProjectOrgId_RunAsAdmin_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var anotherOrg = await _organizationHelper.CreateSmOrganizationAsync(); var project = await _projectRepository.CreateAsync(new Project { Name = "123", OrganizationId = anotherOrg.Id }); @@ -624,7 +606,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task UpdateWithMultipleProjects_BadRequest() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var projectA = await _projectRepository.CreateAsync(new Project { OrganizationId = org.Id, Name = "123A" }); var projectB = await _projectRepository.CreateAsync(new Project { OrganizationId = org.Id, Name = "123B" }); @@ -660,7 +642,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Delete_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -680,33 +662,34 @@ public class SecretsControllerTests : IClassFixture, IAsy { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); - var (_, secretIds) = await CreateSecretsAsync(org.Id, 3); + var (_, secretIds) = await CreateSecretsAsync(org.Id); var response = await _client.PostAsync("/secrets/delete", JsonContent.Create(secretIds)); var results = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(results); Assert.Equal(secretIds.OrderBy(x => x), - results!.Data.Select(x => x.Id).OrderBy(x => x)); + results.Data.Select(x => x.Id).OrderBy(x => x)); Assert.All(results.Data, item => Assert.Equal("access denied", item.Error)); } [Theory] [InlineData(PermissionType.RunAsAdmin)] [InlineData(PermissionType.RunAsUserWithPermission)] + [InlineData(PermissionType.RunAsServiceAccountWithPermission)] public async Task Delete_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var (project, secretIds) = await CreateSecretsAsync(org.Id); if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var accessPolicies = new List { @@ -723,8 +706,8 @@ public class SecretsControllerTests : IClassFixture, IAsy var results = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(results?.Data); - Assert.Equal(secretIds.Count, results!.Data.Count()); - foreach (var result in results!.Data) + Assert.Equal(secretIds.Count, results.Data.Count()); + foreach (var result in results.Data) { Assert.Contains(result.Id, secretIds); Assert.Null(result.Error); @@ -745,7 +728,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByIds_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -767,14 +750,14 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByIds_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var (project, secretIds) = await CreateSecretsAsync(org.Id); if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var accessPolicies = new List { @@ -788,7 +771,7 @@ public class SecretsControllerTests : IClassFixture, IAsy else { var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.Admin, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); } var request = new GetSecretsRequestModel { Ids = secretIds }; @@ -797,8 +780,8 @@ public class SecretsControllerTests : IClassFixture, IAsy response.EnsureSuccessStatusCode(); var result = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(result); - Assert.NotEmpty(result!.Data); - Assert.Equal(secretIds.Count, result!.Data.Count()); + Assert.NotEmpty(result.Data); + Assert.Equal(secretIds.Count, result.Data.Count()); } private async Task<(Project Project, List secretIds)> CreateSecretsAsync(Guid orgId, int numberToCreate = 3) @@ -826,4 +809,48 @@ public class SecretsControllerTests : IClassFixture, IAsy return (project, secretIds); } + + private async Task SetupProjectPermissionAndLoginAsync(PermissionType permissionType, Project project) + { + switch (permissionType) + { + case PermissionType.RunAsAdmin: + { + await _loginHelper.LoginAsync(_email); + break; + } + case PermissionType.RunAsUserWithPermission: + { + var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); + await _loginHelper.LoginAsync(email); + + var accessPolicies = new List + { + new UserProjectAccessPolicy + { + GrantedProjectId = project.Id, OrganizationUserId = orgUser.Id, Read = true, Write = true, + }, + }; + await _accessPolicyRepository.CreateManyAsync(accessPolicies); + break; + } + case PermissionType.RunAsServiceAccountWithPermission: + { + var apiKeyDetails = await _organizationHelper.CreateNewServiceAccountApiKeyAsync(); + await _loginHelper.LoginWithApiKeyAsync(apiKeyDetails); + + var accessPolicies = new List + { + new ServiceAccountProjectAccessPolicy + { + GrantedProjectId = project.Id, ServiceAccountId = apiKeyDetails.ApiKey.ServiceAccountId, Read = true, Write = true, + }, + }; + await _accessPolicyRepository.CreateManyAsync(accessPolicies); + break; + } + default: + throw new ArgumentOutOfRangeException(nameof(permissionType), permissionType, null); + } + } } diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerEventsControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerEventsControllerTests.cs index 4c053c3a2e..036e307d39 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerEventsControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerEventsControllerTests.cs @@ -1,6 +1,7 @@ using System.Net; using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Core.SecretsManager.Entities; using Bit.Core.SecretsManager.Repositories; using Xunit; diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerPortingControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerPortingControllerTests.cs index c57ceb20d9..ba41c1e862 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerPortingControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerPortingControllerTests.cs @@ -1,8 +1,7 @@ using System.Net; -using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.SecretsManager.Models.Request; -using Bit.Core.SecretsManager.Repositories; using Xunit; namespace Bit.Api.IntegrationTest.SecretsManager.Controllers; @@ -11,8 +10,7 @@ public class SecretsManagerPortingControllerTests : IClassFixture(); - _accessPolicyRepository = _factory.GetService(); + _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -38,12 +35,6 @@ public class SecretsManagerPortingControllerTests : IClassFixture(); var secretsList = new List(); @@ -76,7 +67,7 @@ public class SecretsManagerPortingControllerTests : IClassFixture, private readonly HttpClient _client; private readonly ApiApplicationFactory _factory; private readonly ISecretRepository _secretRepository; + private readonly LoginHelper _loginHelper; private string _email = null!; private SecretsManagerOrganizationHelper _organizationHelper = null!; @@ -26,6 +27,7 @@ public class SecretsTrashControllerTests : IClassFixture, _factory = factory; _client = _factory.CreateClient(); _secretRepository = _factory.GetService(); + _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -41,12 +43,6 @@ public class SecretsTrashControllerTests : IClassFixture, return Task.CompletedTask; } - private async Task LoginAsync(string email) - { - var tokens = await _factory.LoginAsync(email); - _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); - } - [Theory] [InlineData(false, false, false)] [InlineData(false, false, true)] @@ -58,7 +54,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task ListByOrganization_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var response = await _client.GetAsync($"/secrets/{org.Id}/trash"); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -69,7 +65,7 @@ public class SecretsTrashControllerTests : IClassFixture, { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var response = await _client.GetAsync($"/secrets/{org.Id}/trash"); Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); @@ -79,7 +75,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task ListByOrganization_Success() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); await _secretRepository.CreateAsync(new Secret { @@ -114,7 +110,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Empty_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var ids = new List { Guid.NewGuid() }; var response = await _client.PostAsJsonAsync($"/secrets/{org.Id}/trash/empty", ids); @@ -126,7 +122,7 @@ public class SecretsTrashControllerTests : IClassFixture, { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var ids = new List { Guid.NewGuid() }; var response = await _client.PostAsJsonAsync($"/secrets/{org.Id}/trash/empty", ids); @@ -137,7 +133,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Empty_Invalid_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -155,7 +151,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Empty_Success() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -181,7 +177,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Restore_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var ids = new List { Guid.NewGuid() }; var response = await _client.PostAsJsonAsync($"/secrets/{org.Id}/trash/restore", ids); @@ -193,7 +189,7 @@ public class SecretsTrashControllerTests : IClassFixture, { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await LoginAsync(email); + await _loginHelper.LoginAsync(email); var ids = new List { Guid.NewGuid() }; var response = await _client.PostAsJsonAsync($"/secrets/{org.Id}/trash/restore", ids); @@ -204,7 +200,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Restore_Invalid_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -222,7 +218,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Restore_Success() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/ServiceAccountsControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/ServiceAccountsControllerTests.cs index a482d9b04e..f25005b269 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/ServiceAccountsControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/ServiceAccountsControllerTests.cs @@ -1,7 +1,7 @@ using System.Net; -using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.SecretsManager.Enums; +using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.Models.Response; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.SecretsManager.Models.Response; @@ -24,6 +24,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); _accessPolicyRepository = _factory.GetService(); _apiKeyRepository = _factory.GetService(); + _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -54,12 +56,6 @@ public class ServiceAccountsControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.NotEmpty(result!.Data); + Assert.NotEmpty(result.Data); Assert.Equal(serviceAccountIds.Count, result.Data.Count()); } @@ -99,7 +95,7 @@ public class ServiceAccountsControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.NotEmpty(result!.Data); + Assert.NotEmpty(result.Data); Assert.Equal(2, result.Data.Count()); } @@ -135,7 +131,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(serviceAccount.Id, result!.Id); + Assert.Equal(serviceAccount.Id, result.Id); Assert.Equal(serviceAccount.OrganizationId, result.OrganizationId); Assert.Equal(serviceAccount.Name, result.Name); Assert.Equal(serviceAccount.CreationDate, result.CreationDate); @@ -203,7 +199,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result!.Name); + Assert.Equal(request.Name, result.Name); AssertHelper.AssertRecent(result.RevisionDate); AssertHelper.AssertRecent(result.CreationDate); @@ -270,7 +266,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result!.Name); + Assert.Equal(request.Name, result.Name); Assert.NotEqual(initialServiceAccount.Name, result.Name); AssertHelper.AssertRecent(result.RevisionDate); Assert.NotEqual(initialServiceAccount.RevisionDate, result.RevisionDate); @@ -353,7 +349,7 @@ public class ServiceAccountsControllerTests : IClassFixture { new UserServiceAccountAccessPolicy @@ -443,7 +439,7 @@ public class ServiceAccountsControllerTests : IClassFixture { new UserServiceAccountAccessPolicy @@ -540,7 +536,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result!.Name); + Assert.Equal(request.Name, result.Name); Assert.NotNull(result.ClientSecret); Assert.Equal(mockExpiresAt, result.ExpireAt); AssertHelper.AssertRecent(result.RevisionDate); @@ -599,7 +595,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result!.Name); + Assert.Equal(request.Name, result.Name); Assert.NotNull(result.ClientSecret); Assert.Equal(mockExpiresAt, result.ExpireAt); AssertHelper.AssertRecent(result.RevisionDate); @@ -635,7 +631,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result!.Name); + Assert.Equal(request.Name, result.Name); Assert.NotNull(result.ClientSecret); Assert.Null(result.ExpireAt); AssertHelper.AssertRecent(result.RevisionDate); @@ -699,7 +695,7 @@ public class ServiceAccountsControllerTests : IClassFixture { new UserServiceAccountAccessPolicy @@ -847,7 +843,7 @@ public class ServiceAccountsControllerTests : IClassFixture SetupServiceAccountWithAccessAsync(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await LoginAsync(_email); + await _loginHelper.LoginAsync(_email); var initialServiceAccount = await _serviceAccountRepository.CreateAsync(new ServiceAccount { @@ -861,7 +857,7 @@ public class ServiceAccountsControllerTests : IClassFixture { diff --git a/test/Api.IntegrationTest/SecretsManager/Enums/PermissionType.cs b/test/Api.IntegrationTest/SecretsManager/Enums/PermissionType.cs index 7f1c4d7b99..972bc7f0be 100644 --- a/test/Api.IntegrationTest/SecretsManager/Enums/PermissionType.cs +++ b/test/Api.IntegrationTest/SecretsManager/Enums/PermissionType.cs @@ -4,4 +4,5 @@ public enum PermissionType { RunAsAdmin, RunAsUserWithPermission, + RunAsServiceAccountWithPermission, } diff --git a/test/Api.IntegrationTest/SecretsManager/Helpers/LoginHelper.cs b/test/Api.IntegrationTest/SecretsManager/Helpers/LoginHelper.cs new file mode 100644 index 0000000000..9de66bc11e --- /dev/null +++ b/test/Api.IntegrationTest/SecretsManager/Helpers/LoginHelper.cs @@ -0,0 +1,30 @@ +using System.Net.Http.Headers; +using Bit.Api.IntegrationTest.Factories; +using Bit.Core.SecretsManager.Models.Data; + +namespace Bit.Api.IntegrationTest.SecretsManager.Helpers; + +public class LoginHelper +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + + public LoginHelper(ApiApplicationFactory factory, HttpClient client) + { + _factory = factory; + _client = client; + } + + public async Task LoginAsync(string email) + { + var tokens = await _factory.LoginAsync(email); + _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); + } + + public async Task LoginWithApiKeyAsync(ApiKeyClientSecretDetails apiKeyDetails) + { + var token = await _factory.LoginWithClientSecretAsync(apiKeyDetails.ApiKey.Id, apiKeyDetails.ClientSecret); + _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token); + _client.DefaultRequestHeaders.Add("service_account_id", apiKeyDetails.ApiKey.ServiceAccountId.ToString()); + } +} diff --git a/test/Api.IntegrationTest/SecretsManager/SecretsManagerOrganizationHelper.cs b/test/Api.IntegrationTest/SecretsManager/Helpers/SecretsManagerOrganizationHelper.cs similarity index 58% rename from test/Api.IntegrationTest/SecretsManager/SecretsManagerOrganizationHelper.cs rename to test/Api.IntegrationTest/SecretsManager/Helpers/SecretsManagerOrganizationHelper.cs index fea05de311..d2d03d979e 100644 --- a/test/Api.IntegrationTest/SecretsManager/SecretsManagerOrganizationHelper.cs +++ b/test/Api.IntegrationTest/SecretsManager/Helpers/SecretsManagerOrganizationHelper.cs @@ -4,8 +4,12 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; +using Bit.Core.SecretsManager.Commands.AccessTokens.Interfaces; +using Bit.Core.SecretsManager.Entities; +using Bit.Core.SecretsManager.Models.Data; +using Bit.Core.SecretsManager.Repositories; -namespace Bit.Api.IntegrationTest.SecretsManager; +namespace Bit.Api.IntegrationTest.SecretsManager.Helpers; public class SecretsManagerOrganizationHelper { @@ -13,17 +17,20 @@ public class SecretsManagerOrganizationHelper private readonly string _ownerEmail; private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IServiceAccountRepository _serviceAccountRepository; + private readonly ICreateAccessTokenCommand _createAccessTokenCommand; - public Organization _organization = null!; - public OrganizationUser _owner = null!; + private Organization _organization = null!; + private OrganizationUser _owner = null!; public SecretsManagerOrganizationHelper(ApiApplicationFactory factory, string ownerEmail) { _factory = factory; _organizationRepository = factory.GetService(); _organizationUserRepository = factory.GetService(); - _ownerEmail = ownerEmail; + _serviceAccountRepository = factory.GetService(); + _createAccessTokenCommand = factory.GetService(); } public async Task<(Organization organization, OrganizationUser owner)> Initialize(bool useSecrets, bool ownerAccessSecrets, bool organizationEnabled) @@ -58,8 +65,7 @@ public class SecretsManagerOrganizationHelper { var email = $"integration-test{Guid.NewGuid()}@bitwarden.com"; await _factory.LoginWithNewAccount(email); - var (organization, owner) = - await OrganizationTestHelpers.SignUpAsync(_factory, ownerEmail: email, billingEmail: email); + var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, ownerEmail: email, billingEmail: email); return organization; } @@ -71,4 +77,29 @@ public class SecretsManagerOrganizationHelper return (email, orgUser); } + + public async Task CreateNewServiceAccountApiKeyAsync() + { + var serviceAccountId = Guid.NewGuid(); + var serviceAccount = new ServiceAccount + { + Id = serviceAccountId, + OrganizationId = _organization.Id, + Name = $"integration-test-{serviceAccountId}sa", + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow + }; + await _serviceAccountRepository.CreateAsync(serviceAccount); + + var apiKey = new ApiKey + { + ServiceAccountId = serviceAccountId, + Name = "integration-token", + Key = Guid.NewGuid().ToString(), + ExpireAt = null, + Scope = "[\"api.secrets\"]", + EncryptedPayload = Guid.NewGuid().ToString() + }; + return await _createAccessTokenCommand.CreateAsync(apiKey); + } } diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs index fdbcc17e46..9d3c7ebfe5 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -56,7 +56,7 @@ public class OrganizationsControllerTests : IDisposable private readonly IAddSecretsManagerSubscriptionCommand _addSecretsManagerSubscriptionCommand; private readonly IPushNotificationService _pushNotificationService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly IGetSubscriptionQuery _getSubscriptionQuery; + private readonly ISubscriberQueries _subscriberQueries; private readonly IReferenceEventService _referenceEventService; private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand; @@ -86,7 +86,7 @@ public class OrganizationsControllerTests : IDisposable _addSecretsManagerSubscriptionCommand = Substitute.For(); _pushNotificationService = Substitute.For(); _cancelSubscriptionCommand = Substitute.For(); - _getSubscriptionQuery = Substitute.For(); + _subscriberQueries = Substitute.For(); _referenceEventService = Substitute.For(); _organizationEnableCollectionEnhancementsCommand = Substitute.For(); @@ -113,7 +113,7 @@ public class OrganizationsControllerTests : IDisposable _addSecretsManagerSubscriptionCommand, _pushNotificationService, _cancelSubscriptionCommand, - _getSubscriptionQuery, + _subscriberQueries, _referenceEventService, _organizationEnableCollectionEnhancementsCommand); } diff --git a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs index 79aa2ca13d..4af60689c3 100644 --- a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs @@ -57,7 +57,7 @@ public class AccountsControllerTests : IDisposable private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IFeatureService _featureService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly IGetSubscriptionQuery _getSubscriptionQuery; + private readonly ISubscriberQueries _subscriberQueries; private readonly IReferenceEventService _referenceEventService; private readonly ICurrentContext _currentContext; @@ -90,7 +90,7 @@ public class AccountsControllerTests : IDisposable _rotateUserKeyCommand = Substitute.For(); _featureService = Substitute.For(); _cancelSubscriptionCommand = Substitute.For(); - _getSubscriptionQuery = Substitute.For(); + _subscriberQueries = Substitute.For(); _referenceEventService = Substitute.For(); _currentContext = Substitute.For(); _cipherValidator = @@ -122,7 +122,7 @@ public class AccountsControllerTests : IDisposable _rotateUserKeyCommand, _featureService, _cancelSubscriptionCommand, - _getSubscriptionQuery, + _subscriberQueries, _referenceEventService, _currentContext, _cipherValidator, diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs new file mode 100644 index 0000000000..57480ac116 --- /dev/null +++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs @@ -0,0 +1,130 @@ +using Bit.Api.Billing.Controllers; +using Bit.Api.Billing.Models; +using Bit.Core; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Queries; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Services; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Http.HttpResults; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +namespace Bit.Api.Test.Billing.Controllers; + +[ControllerCustomize(typeof(ProviderBillingController))] +[SutProviderCustomize] +public class ProviderBillingControllerTests +{ + [Theory, BitAutoData] + public async Task GetSubscriptionAsync_FFDisabled_NotFound( + Guid providerId, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(false); + + var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized( + Guid providerId, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(false); + + var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionAsync_NoSubscriptionData_NotFound( + Guid providerId, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + sutProvider.GetDependency().GetSubscriptionData(providerId).ReturnsNull(); + + var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionAsync_OK( + Guid providerId, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + var configuredPlans = new List + { + new (Guid.NewGuid(), providerId, PlanType.TeamsMonthly, 50, 10, 30), + new (Guid.NewGuid(), providerId, PlanType.EnterpriseMonthly, 100, 0, 90) + }; + + var subscription = new Subscription + { + Status = "active", + CurrentPeriodEnd = new DateTime(2025, 1, 1), + Customer = new Customer { Discount = new Discount { Coupon = new Coupon { PercentOff = 10 } } } + }; + + var providerSubscriptionData = new ProviderSubscriptionData( + configuredPlans, + subscription); + + sutProvider.GetDependency().GetSubscriptionData(providerId) + .Returns(providerSubscriptionData); + + var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); + + Assert.IsType>(result); + + var providerSubscriptionDTO = ((Ok)result).Value; + + Assert.Equal(providerSubscriptionDTO.Status, subscription.Status); + Assert.Equal(providerSubscriptionDTO.CurrentPeriodEndDate, subscription.CurrentPeriodEnd); + Assert.Equal(providerSubscriptionDTO.DiscountPercentage, subscription.Customer!.Discount!.Coupon!.PercentOff); + + var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var providerTeamsPlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name); + Assert.NotNull(providerTeamsPlan); + Assert.Equal(50, providerTeamsPlan.SeatMinimum); + Assert.Equal(10, providerTeamsPlan.PurchasedSeats); + Assert.Equal(30, providerTeamsPlan.AssignedSeats); + Assert.Equal(60 * teamsPlan.PasswordManager.SeatPrice, providerTeamsPlan.Cost); + Assert.Equal("Monthly", providerTeamsPlan.Cadence); + + var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var providerEnterprisePlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name); + Assert.NotNull(providerEnterprisePlan); + Assert.Equal(100, providerEnterprisePlan.SeatMinimum); + Assert.Equal(0, providerEnterprisePlan.PurchasedSeats); + Assert.Equal(90, providerEnterprisePlan.AssignedSeats); + Assert.Equal(100 * enterprisePlan.PasswordManager.SeatPrice, providerEnterprisePlan.Cost); + Assert.Equal("Monthly", providerEnterprisePlan.Cadence); + } +} diff --git a/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs new file mode 100644 index 0000000000..805683de27 --- /dev/null +++ b/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs @@ -0,0 +1,168 @@ +using Bit.Api.Billing.Controllers; +using Bit.Api.Billing.Models; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Commands; +using Bit.Core.Context; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Http.HttpResults; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Xunit; +using ProviderOrganization = Bit.Core.AdminConsole.Entities.Provider.ProviderOrganization; + +namespace Bit.Api.Test.Billing.Controllers; + +[ControllerCustomize(typeof(ProviderOrganizationController))] +[SutProviderCustomize] +public class ProviderOrganizationControllerTests +{ + [Theory, BitAutoData] + public async Task UpdateAsync_FFDisabled_NotFound( + Guid providerId, + Guid providerOrganizationId, + UpdateProviderOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(false); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized( + Guid providerId, + Guid providerOrganizationId, + UpdateProviderOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(false); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionAsync_NoProvider_NotFound( + Guid providerId, + Guid providerOrganizationId, + UpdateProviderOrganizationRequestBody requestBody, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .ReturnsNull(); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionAsync_NoProviderOrganization_NotFound( + Guid providerId, + Guid providerOrganizationId, + UpdateProviderOrganizationRequestBody requestBody, + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .Returns(provider); + + sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) + .ReturnsNull(); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionAsync_NoOrganization_ServerError( + Guid providerId, + Guid providerOrganizationId, + UpdateProviderOrganizationRequestBody requestBody, + Provider provider, + ProviderOrganization providerOrganization, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .Returns(provider); + + sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) + .Returns(providerOrganization); + + sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) + .ReturnsNull(); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionAsync_NoContent( + Guid providerId, + Guid providerOrganizationId, + UpdateProviderOrganizationRequestBody requestBody, + Provider provider, + ProviderOrganization providerOrganization, + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().ProviderProviderAdmin(providerId) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId) + .Returns(provider); + + sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) + .Returns(providerOrganization); + + sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) + .Returns(organization); + + var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + + await sutProvider.GetDependency().Received(1) + .AssignSeatsToClientOrganization( + provider, + organization, + requestBody.AssignedSeats); + + Assert.IsType(result); + } +} diff --git a/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs b/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs new file mode 100644 index 0000000000..918b7c47a2 --- /dev/null +++ b/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs @@ -0,0 +1,339 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing; +using Bit.Core.Billing.Commands.Implementations; +using Bit.Core.Billing.Entities; +using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Repositories; +using Bit.Core.Enums; +using Bit.Core.Models.StaticStore; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +using static Bit.Core.Test.Billing.Utilities; + +namespace Bit.Core.Test.Billing.Commands; + +[SutProviderCustomize] +public class AssignSeatsToClientOrganizationCommandTests +{ + [Theory, BitAutoData] + public Task AssignSeatsToClientOrganization_NullProvider_ArgumentNullException( + Organization organization, + int seats, + SutProvider sutProvider) + => Assert.ThrowsAsync(() => + sutProvider.Sut.AssignSeatsToClientOrganization(null, organization, seats)); + + [Theory, BitAutoData] + public Task AssignSeatsToClientOrganization_NullOrganization_ArgumentNullException( + Provider provider, + int seats, + SutProvider sutProvider) + => Assert.ThrowsAsync(() => + sutProvider.Sut.AssignSeatsToClientOrganization(provider, null, seats)); + + [Theory, BitAutoData] + public Task AssignSeatsToClientOrganization_NegativeSeats_BillingException( + Provider provider, + Organization organization, + SutProvider sutProvider) + => Assert.ThrowsAsync(() => + sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, -5)); + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_CurrentSeatsMatchesNewSeats_NoOp( + Provider provider, + Organization organization, + int seats, + SutProvider sutProvider) + { + organization.PlanType = PlanType.TeamsMonthly; + + organization.Seats = seats; + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + await sutProvider.GetDependency().DidNotReceive().GetByProviderId(provider.Id); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_OrganizationPlanTypeDoesNotSupportConsolidatedBilling_ContactSupport( + Provider provider, + Organization organization, + int seats, + SutProvider sutProvider) + { + organization.PlanType = PlanType.FamiliesAnnually; + + await ThrowsContactSupportAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_ProviderPlanIsNotConfigured_ContactSupport( + Provider provider, + Organization organization, + int seats, + SutProvider sutProvider) + { + organization.PlanType = PlanType.TeamsMonthly; + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(new List + { + new () + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id + } + }); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_BelowToBelow_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.Seats = 10; + + organization.PlanType = PlanType.TeamsMonthly; + + // Scale up 10 seats + const int seats = 20; + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + // 100 minimum + SeatMinimum = 100, + AllocatedSeats = 50 + }, + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.EnterpriseMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + SeatMinimum = 500, + AllocatedSeats = 0 + } + }; + + var providerPlan = providerPlans.First(); + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + // 50 seats currently assigned with a seat minimum of 100 + sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(50); + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().AdjustSeats( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.Id == organization.Id && org.Seats == seats)); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + pPlan => pPlan.AllocatedSeats == 60)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_BelowToAbove_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.Seats = 10; + + organization.PlanType = PlanType.TeamsMonthly; + + // Scale up 10 seats + const int seats = 20; + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + // 100 minimum + SeatMinimum = 100, + AllocatedSeats = 95 + }, + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.EnterpriseMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + SeatMinimum = 500, + AllocatedSeats = 0 + } + }; + + var providerPlan = providerPlans.First(); + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + // 95 seats currently assigned with a seat minimum of 100 + sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(95); + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + // 95 current + 10 seat scale = 105 seats, 5 above the minimum + await sutProvider.GetDependency().Received(1).AdjustSeats( + provider, + StaticStore.GetPlan(providerPlan.PlanType), + providerPlan.SeatMinimum!.Value, + 105); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.Id == organization.Id && org.Seats == seats)); + + // 105 total seats - 100 minimum = 5 purchased seats + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 5 && pPlan.AllocatedSeats == 105)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_AboveToAbove_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.Seats = 10; + + organization.PlanType = PlanType.TeamsMonthly; + + // Scale up 10 seats + const int seats = 20; + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id, + // 10 additional purchased seats + PurchasedSeats = 10, + // 100 seat minimum + SeatMinimum = 100, + AllocatedSeats = 110 + }, + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.EnterpriseMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + SeatMinimum = 500, + AllocatedSeats = 0 + } + }; + + var providerPlan = providerPlans.First(); + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + // 110 seats currently assigned with a seat minimum of 100 + sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(110); + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + // 110 current + 10 seat scale up = 120 seats + await sutProvider.GetDependency().Received(1).AdjustSeats( + provider, + StaticStore.GetPlan(providerPlan.PlanType), + 110, + 120); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.Id == organization.Id && org.Seats == seats)); + + // 120 total seats - 100 seat minimum = 20 purchased seats + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 20 && pPlan.AllocatedSeats == 120)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_AboveToBelow_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.Seats = 50; + + organization.PlanType = PlanType.TeamsMonthly; + + // Scale down 30 seats + const int seats = 20; + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id, + // 10 additional purchased seats + PurchasedSeats = 10, + // 100 seat minimum + SeatMinimum = 100, + AllocatedSeats = 110 + }, + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.EnterpriseMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + SeatMinimum = 500, + AllocatedSeats = 0 + } + }; + + var providerPlan = providerPlans.First(); + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + // 110 seats currently assigned with a seat minimum of 100 + sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(110); + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. + await sutProvider.GetDependency().Received(1).AdjustSeats( + provider, + StaticStore.GetPlan(providerPlan.PlanType), + 110, + providerPlan.SeatMinimum!.Value); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.Id == organization.Id && org.Seats == seats)); + + // Being below the seat minimum means no purchased seats. + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 0 && pPlan.AllocatedSeats == 80)); + } +} diff --git a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs index 5de14f006f..968bfeb84d 100644 --- a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs +++ b/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs @@ -1,13 +1,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Commands.Implementations; using Bit.Core.Enums; -using Bit.Core.Exceptions; using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using NSubstitute.ReturnsExtensions; using Xunit; +using static Bit.Core.Test.Billing.Utilities; using BT = Braintree; using S = Stripe; @@ -355,13 +355,4 @@ public class RemovePaymentMethodCommandTests return (braintreeGateway, customerGateway, paymentMethodGateway); } - - private static async Task ThrowsContactSupportAsync(Func function) - { - const string message = "Could not remove your payment method. Please contact support for assistance."; - - var exception = await Assert.ThrowsAsync(function); - - Assert.Equal(message, exception.Message); - } } diff --git a/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs b/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs deleted file mode 100644 index adae46a791..0000000000 --- a/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs +++ /dev/null @@ -1,104 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Queries.Implementations; -using Bit.Core.Entities; -using Bit.Core.Exceptions; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Stripe; -using Xunit; - -namespace Bit.Core.Test.Billing.Queries; - -[SutProviderCustomize] -public class GetSubscriptionQueryTests -{ - [Theory, BitAutoData] - public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) - => await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetSubscription(null)); - - [Theory, BitAutoData] - public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ThrowsGatewayException( - Organization organization, - SutProvider sutProvider) - { - organization.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(organization)); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Organization_NoSubscription_ThrowsGatewayException( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(organization)); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Organization_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscription(organization); - - Assert.Equivalent(subscription, gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_User_NoGatewaySubscriptionId_ThrowsGatewayException( - User user, - SutProvider sutProvider) - { - user.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(user)); - } - - [Theory, BitAutoData] - public async Task GetSubscription_User_NoSubscription_ThrowsGatewayException( - User user, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(user)); - } - - [Theory, BitAutoData] - public async Task GetSubscription_User_Succeeds( - User user, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscription(user); - - Assert.Equivalent(subscription, gotSubscription); - } - - private static async Task ThrowsContactSupportAsync(Func function) - { - const string message = "Something went wrong with your request. Please contact support."; - - var exception = await Assert.ThrowsAsync(function); - - Assert.Equal(message, exception.Message); - } -} diff --git a/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs new file mode 100644 index 0000000000..534444ba94 --- /dev/null +++ b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs @@ -0,0 +1,154 @@ +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Entities; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Queries.Implementations; +using Bit.Core.Billing.Repositories; +using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Queries; + +[SutProviderCustomize] +public class ProviderBillingQueriesTests +{ + #region GetSubscriptionData + + [Theory, BitAutoData] + public async Task GetSubscriptionData_NullProvider_ReturnsNull( + SutProvider sutProvider, + Guid providerId) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).ReturnsNull(); + + var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); + + Assert.Null(subscriptionData); + + await providerRepository.Received(1).GetByIdAsync(providerId); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionData_NullSubscription_ReturnsNull( + SutProvider sutProvider, + Guid providerId, + Provider provider) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).Returns(provider); + + var subscriberQueries = sutProvider.GetDependency(); + + subscriberQueries.GetSubscription(provider).ReturnsNull(); + + var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); + + Assert.Null(subscriptionData); + + await providerRepository.Received(1).GetByIdAsync(providerId); + + await subscriberQueries.Received(1).GetSubscription( + provider, + Arg.Is( + options => options.Expand.Count == 1 && options.Expand.First() == "customer")); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionData_Success( + SutProvider sutProvider, + Guid providerId, + Provider provider) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).Returns(provider); + + var subscriberQueries = sutProvider.GetDependency(); + + var subscription = new Subscription(); + + subscriberQueries.GetSubscription(provider, Arg.Is( + options => options.Expand.Count == 1 && options.Expand.First() == "customer")).Returns(subscription); + + var providerPlanRepository = sutProvider.GetDependency(); + + var enterprisePlan = new ProviderPlan + { + Id = Guid.NewGuid(), + ProviderId = providerId, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }; + + var teamsPlan = new ProviderPlan + { + Id = Guid.NewGuid(), + ProviderId = providerId, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 50, + PurchasedSeats = 10, + AllocatedSeats = 60 + }; + + var providerPlans = new List + { + enterprisePlan, + teamsPlan, + }; + + providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans); + + var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); + + Assert.NotNull(subscriptionData); + + Assert.Equivalent(subscriptionData.Subscription, subscription); + + Assert.Equal(2, subscriptionData.ProviderPlans.Count); + + var configuredEnterprisePlan = + subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => + configuredPlan.PlanType == PlanType.EnterpriseMonthly); + + var configuredTeamsPlan = + subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => + configuredPlan.PlanType == PlanType.TeamsMonthly); + + Compare(enterprisePlan, configuredEnterprisePlan); + + Compare(teamsPlan, configuredTeamsPlan); + + await providerRepository.Received(1).GetByIdAsync(providerId); + + await subscriberQueries.Received(1).GetSubscription( + provider, + Arg.Is( + options => options.Expand.Count == 1 && options.Expand.First() == "customer")); + + await providerPlanRepository.Received(1).GetByProviderId(providerId); + + return; + + void Compare(ProviderPlan providerPlan, ConfiguredProviderPlan configuredProviderPlan) + { + Assert.NotNull(configuredProviderPlan); + Assert.Equal(providerPlan.Id, configuredProviderPlan.Id); + Assert.Equal(providerPlan.ProviderId, configuredProviderPlan.ProviderId); + Assert.Equal(providerPlan.SeatMinimum!.Value, configuredProviderPlan.SeatMinimum); + Assert.Equal(providerPlan.PurchasedSeats!.Value, configuredProviderPlan.PurchasedSeats); + Assert.Equal(providerPlan.AllocatedSeats!.Value, configuredProviderPlan.AssignedSeats); + } + } + #endregion +} diff --git a/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs new file mode 100644 index 0000000000..51682a6661 --- /dev/null +++ b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs @@ -0,0 +1,263 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Queries.Implementations; +using Bit.Core.Entities; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +using static Bit.Core.Test.Billing.Utilities; + +namespace Bit.Core.Test.Billing.Queries; + +[SutProviderCustomize] +public class SubscriberQueriesTests +{ + #region GetSubscription + [Theory, BitAutoData] + public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetSubscription(null)); + + [Theory, BitAutoData] + public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = null; + + var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Organization_NoSubscription_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .ReturnsNull(); + + var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Organization_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Equivalent(subscription, gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_User_NoGatewaySubscriptionId_ReturnsNull( + User user, + SutProvider sutProvider) + { + user.GatewaySubscriptionId = null; + + var gotSubscription = await sutProvider.Sut.GetSubscription(user); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_User_NoSubscription_ReturnsNull( + User user, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .ReturnsNull(); + + var gotSubscription = await sutProvider.Sut.GetSubscription(user); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_User_Succeeds( + User user, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscription(user); + + Assert.Equivalent(subscription, gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Provider_NoGatewaySubscriptionId_ReturnsNull( + Provider provider, + SutProvider sutProvider) + { + provider.GatewaySubscriptionId = null; + + var gotSubscription = await sutProvider.Sut.GetSubscription(provider); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Provider_NoSubscription_ReturnsNull( + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) + .ReturnsNull(); + + var gotSubscription = await sutProvider.Sut.GetSubscription(provider); + + Assert.Null(gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Provider_Succeeds( + Provider provider, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscription(provider); + + Assert.Equivalent(subscription, gotSubscription); + } + #endregion + + #region GetSubscriptionOrThrow + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetSubscriptionOrThrow(null)); + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Organization_NoGatewaySubscriptionId_ThrowsGatewayException( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Organization_NoSubscription_ThrowsGatewayException( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Organization_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization); + + Assert.Equivalent(subscription, gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_User_NoGatewaySubscriptionId_ThrowsGatewayException( + User user, + SutProvider sutProvider) + { + user.GatewaySubscriptionId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_User_NoSubscription_ThrowsGatewayException( + User user, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_User_Succeeds( + User user, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(user); + + Assert.Equivalent(subscription, gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Provider_NoGatewaySubscriptionId_ThrowsGatewayException( + Provider provider, + SutProvider sutProvider) + { + provider.GatewaySubscriptionId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Provider_NoSubscription_ThrowsGatewayException( + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Provider_Succeeds( + Provider provider, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(provider); + + Assert.Equivalent(subscription, gotSubscription); + } + #endregion +} diff --git a/test/Core.Test/Billing/Utilities.cs b/test/Core.Test/Billing/Utilities.cs index 359c010a29..ea9e6c694c 100644 --- a/test/Core.Test/Billing/Utilities.cs +++ b/test/Core.Test/Billing/Utilities.cs @@ -1,4 +1,4 @@ -using Bit.Core.Exceptions; +using Bit.Core.Billing; using Xunit; using static Bit.Core.Billing.Utilities; @@ -11,7 +11,7 @@ public static class Utilities { var contactSupport = ContactSupport(); - var exception = await Assert.ThrowsAsync(function); + var exception = await Assert.ThrowsAsync(function); Assert.Equal(contactSupport.Message, exception.Message); } diff --git a/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs b/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs index def8f5c14c..2dc23056d1 100644 --- a/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs +++ b/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs @@ -42,4 +42,23 @@ public class IdentityApplicationFactory : WebApplicationFactoryBase return (root.GetProperty("access_token").GetString(), root.GetProperty("refresh_token").GetString()); } + + public async Task TokenFromAccessTokenAsync(Guid clientId, string clientSecret, + DeviceType deviceType = DeviceType.SDK) + { + var context = await Server.PostAsync("/connect/token", + new FormUrlEncodedContent(new Dictionary + { + { "scope", "api.secrets" }, + { "client_id", clientId.ToString() }, + { "client_secret", clientSecret }, + { "grant_type", "client_credentials" }, + { "deviceType", ((int)deviceType).ToString() } + })); + + using var body = await AssertHelper.AssertResponseTypeIs(context); + var root = body.RootElement; + + return root.GetProperty("access_token").GetString(); + } } diff --git a/util/Migrator/DbMigrator.cs b/util/Migrator/DbMigrator.cs index a6ca53abdb..11b80fac78 100644 --- a/util/Migrator/DbMigrator.cs +++ b/util/Migrator/DbMigrator.cs @@ -13,11 +13,14 @@ public class DbMigrator { private readonly string _connectionString; private readonly ILogger _logger; + private readonly bool _skipDatabasePreparation; - public DbMigrator(string connectionString, ILogger logger = null) + public DbMigrator(string connectionString, ILogger logger = null, + bool skipDatabasePreparation = false) { _connectionString = connectionString; _logger = logger ?? CreateLogger(); + _skipDatabasePreparation = skipDatabasePreparation; } public bool MigrateMsSqlDatabaseWithRetries(bool enableLogging = true, @@ -31,7 +34,10 @@ public class DbMigrator { try { - PrepareDatabase(cancellationToken); + if (!_skipDatabasePreparation) + { + PrepareDatabase(cancellationToken); + } var success = MigrateDatabase(enableLogging, repeatable, folderName, dryRun, cancellationToken); return success; diff --git a/util/Migrator/SqlServerDbMigrator.cs b/util/Migrator/SqlServerDbMigrator.cs index b443260820..d76b26cfb7 100644 --- a/util/Migrator/SqlServerDbMigrator.cs +++ b/util/Migrator/SqlServerDbMigrator.cs @@ -10,7 +10,8 @@ public class SqlServerDbMigrator : IDbMigrator public SqlServerDbMigrator(GlobalSettings globalSettings, ILogger logger) { - _migrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, logger); + _migrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, logger, + globalSettings.SqlServer.SkipDatabasePreparation); } public bool MigrateDatabase(bool enableLogging = true,