diff --git a/.github/renovate.json b/.github/renovate.json index 91774ca33e..18d6e0bb61 100644 --- a/.github/renovate.json +++ b/.github/renovate.json @@ -44,7 +44,6 @@ "matchPackageNames": [ "AspNetCoreRateLimit", "AspNetCoreRateLimit.Redis", - "Azure.Data.Tables", "Azure.Extensions.AspNetCore.DataProtection.Blobs", "Azure.Messaging.EventGrid", "Azure.Messaging.ServiceBus", @@ -54,6 +53,7 @@ "Fido2.AspNet", "Duende.IdentityServer", "Microsoft.Azure.Cosmos", + "Microsoft.Azure.Cosmos.Table", "Microsoft.Extensions.Caching.StackExchangeRedis", "Microsoft.Extensions.Identity.Stores", "Otp.NET", diff --git a/.github/workflows/scan.yml b/.github/workflows/scan.yml index 89d75ccf0f..62203804b9 100644 --- a/.github/workflows/scan.yml +++ b/.github/workflows/scan.yml @@ -10,6 +10,8 @@ on: pull_request_target: types: [opened, synchronize] +permissions: read-all + jobs: check-run: name: Check PR run @@ -20,8 +22,6 @@ jobs: runs-on: ubuntu-22.04 needs: check-run permissions: - contents: read - pull-requests: write security-events: write steps: @@ -43,7 +43,7 @@ jobs: additional_params: --report-format sarif --output-path . ${{ env.INCREMENTAL }} - name: Upload Checkmarx results to GitHub - uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9 + uses: github/codeql-action/upload-sarif@8a470fddafa5cbb6266ee11b37ef4d8aae19c571 # v3.24.6 with: sarif_file: cx_result.sarif @@ -51,9 +51,6 @@ jobs: name: Quality scan runs-on: ubuntu-22.04 needs: check-run - permissions: - contents: read - pull-requests: write steps: - name: Check out repo diff --git a/bitwarden_license/src/Scim/appsettings.json b/bitwarden_license/src/Scim/appsettings.json index dcdfeb3ede..630896a65f 100644 --- a/bitwarden_license/src/Scim/appsettings.json +++ b/bitwarden_license/src/Scim/appsettings.json @@ -30,6 +30,10 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + }, "sentry": { "dsn": "SECRET" }, @@ -54,5 +58,6 @@ "region": "SECRET" } }, - "scimSettings": {} + "scimSettings": { + } } diff --git a/bitwarden_license/src/Sso/appsettings.json b/bitwarden_license/src/Sso/appsettings.json index 73c85044cc..3bf02cd869 100644 --- a/bitwarden_license/src/Sso/appsettings.json +++ b/bitwarden_license/src/Sso/appsettings.json @@ -31,6 +31,10 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" diff --git a/src/Admin/Jobs/JobsHostedService.cs b/src/Admin/Jobs/JobsHostedService.cs index 89cf5512c3..adba27970c 100644 --- a/src/Admin/Jobs/JobsHostedService.cs +++ b/src/Admin/Jobs/JobsHostedService.cs @@ -76,18 +76,14 @@ public class JobsHostedService : BaseJobsHostedService { new Tuple(typeof(DeleteSendsJob), everyFiveMinutesTrigger), new Tuple(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger), + new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger), + new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger), new Tuple(typeof(DeleteCiphersJob), everyDayAtMidnightUtc), new Tuple(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger), new Tuple(typeof(DeleteAuthRequestsJob), everyFifteenMinutesTrigger), new Tuple(typeof(DeleteUnverifiedOrganizationDomainsJob), everyDayAtTwoAmUtcTrigger), }; - if (!(_globalSettings.SqlServer?.DisableDatabaseMaintenanceJobs ?? false)) - { - jobs.Add(new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger)); - jobs.Add(new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger)); - } - if (!_globalSettings.SelfHosted) { jobs.Add(new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger)); diff --git a/src/Admin/Startup.cs b/src/Admin/Startup.cs index 788908d42a..db870266cc 100644 --- a/src/Admin/Startup.cs +++ b/src/Admin/Startup.cs @@ -88,7 +88,7 @@ public class Startup services.AddBaseServices(globalSettings); services.AddDefaultServices(globalSettings); services.AddScoped(); - services.AddBillingOperations(); + services.AddBillingCommands(); #if OSS services.AddOosServices(); diff --git a/src/Admin/appsettings.json b/src/Admin/appsettings.json index 9513dc44a2..4764484204 100644 --- a/src/Admin/appsettings.json +++ b/src/Admin/appsettings.json @@ -30,6 +30,10 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs index 822f9635eb..2a4ba3a1db 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs @@ -66,7 +66,7 @@ public class OrganizationsController : Controller private readonly IAddSecretsManagerSubscriptionCommand _addSecretsManagerSubscriptionCommand; private readonly IPushNotificationService _pushNotificationService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly ISubscriberQueries _subscriberQueries; + private readonly IGetSubscriptionQuery _getSubscriptionQuery; private readonly IReferenceEventService _referenceEventService; private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand; @@ -93,7 +93,7 @@ public class OrganizationsController : Controller IAddSecretsManagerSubscriptionCommand addSecretsManagerSubscriptionCommand, IPushNotificationService pushNotificationService, ICancelSubscriptionCommand cancelSubscriptionCommand, - ISubscriberQueries subscriberQueries, + IGetSubscriptionQuery getSubscriptionQuery, IReferenceEventService referenceEventService, IOrganizationEnableCollectionEnhancementsCommand organizationEnableCollectionEnhancementsCommand) { @@ -119,7 +119,7 @@ public class OrganizationsController : Controller _addSecretsManagerSubscriptionCommand = addSecretsManagerSubscriptionCommand; _pushNotificationService = pushNotificationService; _cancelSubscriptionCommand = cancelSubscriptionCommand; - _subscriberQueries = subscriberQueries; + _getSubscriptionQuery = getSubscriptionQuery; _referenceEventService = referenceEventService; _organizationEnableCollectionEnhancementsCommand = organizationEnableCollectionEnhancementsCommand; } @@ -479,7 +479,7 @@ public class OrganizationsController : Controller throw new NotFoundException(); } - var subscription = await _subscriberQueries.GetSubscriptionOrThrow(organization); + var subscription = await _getSubscriptionQuery.GetSubscription(organization); await _cancelSubscriptionCommand.CancelSubscription(subscription, new OffboardingSurveyResponse diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs index 767f83ee22..df75a34f69 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs @@ -126,14 +126,8 @@ public class OrganizationSubscriptionResponseModel : OrganizationResponseModel if (hideSensitiveData) { BillingEmail = null; - if (Subscription != null) - { - Subscription.Items = null; - } - if (UpcomingInvoice != null) - { - UpcomingInvoice.Amount = null; - } + Subscription.Items = null; + UpcomingInvoice.Amount = null; } } diff --git a/src/Api/Auth/Controllers/AccountsController.cs b/src/Api/Auth/Controllers/AccountsController.cs index 5f1910fb28..29ede684be 100644 --- a/src/Api/Auth/Controllers/AccountsController.cs +++ b/src/Api/Auth/Controllers/AccountsController.cs @@ -69,7 +69,7 @@ public class AccountsController : Controller private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IFeatureService _featureService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly ISubscriberQueries _subscriberQueries; + private readonly IGetSubscriptionQuery _getSubscriptionQuery; private readonly IReferenceEventService _referenceEventService; private readonly ICurrentContext _currentContext; @@ -104,7 +104,7 @@ public class AccountsController : Controller IRotateUserKeyCommand rotateUserKeyCommand, IFeatureService featureService, ICancelSubscriptionCommand cancelSubscriptionCommand, - ISubscriberQueries subscriberQueries, + IGetSubscriptionQuery getSubscriptionQuery, IReferenceEventService referenceEventService, ICurrentContext currentContext, IRotationValidator, IEnumerable> cipherValidator, @@ -133,7 +133,7 @@ public class AccountsController : Controller _rotateUserKeyCommand = rotateUserKeyCommand; _featureService = featureService; _cancelSubscriptionCommand = cancelSubscriptionCommand; - _subscriberQueries = subscriberQueries; + _getSubscriptionQuery = getSubscriptionQuery; _referenceEventService = referenceEventService; _currentContext = currentContext; _cipherValidator = cipherValidator; @@ -831,7 +831,7 @@ public class AccountsController : Controller throw new UnauthorizedAccessException(); } - var subscription = await _subscriberQueries.GetSubscriptionOrThrow(user); + var subscription = await _getSubscriptionQuery.GetSubscription(user); await _cancelSubscriptionCommand.CancelSubscription(subscription, new OffboardingSurveyResponse diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs deleted file mode 100644 index 583a5937e4..0000000000 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ /dev/null @@ -1,44 +0,0 @@ -using Bit.Api.Billing.Models; -using Bit.Core; -using Bit.Core.Billing.Queries; -using Bit.Core.Context; -using Bit.Core.Services; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Billing.Controllers; - -[Route("providers/{providerId:guid}/billing")] -[Authorize("Application")] -public class ProviderBillingController( - ICurrentContext currentContext, - IFeatureService featureService, - IProviderBillingQueries providerBillingQueries) : Controller -{ - [HttpGet("subscription")] - public async Task GetSubscriptionAsync([FromRoute] Guid providerId) - { - if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) - { - return TypedResults.NotFound(); - } - - if (!currentContext.ProviderProviderAdmin(providerId)) - { - return TypedResults.Unauthorized(); - } - - var subscriptionData = await providerBillingQueries.GetSubscriptionData(providerId); - - if (subscriptionData == null) - { - return TypedResults.NotFound(); - } - - var (providerPlans, subscription) = subscriptionData; - - var providerSubscriptionDTO = ProviderSubscriptionDTO.From(providerPlans, subscription); - - return TypedResults.Ok(providerSubscriptionDTO); - } -} diff --git a/src/Api/Billing/Controllers/ProviderOrganizationController.cs b/src/Api/Billing/Controllers/ProviderOrganizationController.cs deleted file mode 100644 index a5cc31c79c..0000000000 --- a/src/Api/Billing/Controllers/ProviderOrganizationController.cs +++ /dev/null @@ -1,63 +0,0 @@ -using Bit.Api.Billing.Models; -using Bit.Core; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Commands; -using Bit.Core.Context; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Billing.Controllers; - -[Route("providers/{providerId:guid}/organizations")] -public class ProviderOrganizationController( - IAssignSeatsToClientOrganizationCommand assignSeatsToClientOrganizationCommand, - ICurrentContext currentContext, - IFeatureService featureService, - ILogger logger, - IOrganizationRepository organizationRepository, - IProviderRepository providerRepository, - IProviderOrganizationRepository providerOrganizationRepository) : Controller -{ - [HttpPut("{providerOrganizationId:guid}")] - public async Task UpdateAsync( - [FromRoute] Guid providerId, - [FromRoute] Guid providerOrganizationId, - [FromBody] UpdateProviderOrganizationRequestBody requestBody) - { - if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) - { - return TypedResults.NotFound(); - } - - if (!currentContext.ProviderProviderAdmin(providerId)) - { - return TypedResults.Unauthorized(); - } - - var provider = await providerRepository.GetByIdAsync(providerId); - - var providerOrganization = await providerOrganizationRepository.GetByIdAsync(providerOrganizationId); - - if (provider == null || providerOrganization == null) - { - return TypedResults.NotFound(); - } - - var organization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId); - - if (organization == null) - { - logger.LogError("The organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id); - - return TypedResults.Problem(); - } - - await assignSeatsToClientOrganizationCommand.AssignSeatsToClientOrganization( - provider, - organization, - requestBody.AssignedSeats); - - return TypedResults.Ok(); - } -} diff --git a/src/Api/Billing/Models/ProviderSubscriptionDTO.cs b/src/Api/Billing/Models/ProviderSubscriptionDTO.cs deleted file mode 100644 index ad0714967d..0000000000 --- a/src/Api/Billing/Models/ProviderSubscriptionDTO.cs +++ /dev/null @@ -1,49 +0,0 @@ -using Bit.Core.Billing.Models; -using Bit.Core.Utilities; -using Stripe; - -namespace Bit.Api.Billing.Models; - -public record ProviderSubscriptionDTO( - string Status, - DateTime CurrentPeriodEndDate, - decimal? DiscountPercentage, - IEnumerable Plans) -{ - private const string _annualCadence = "Annual"; - private const string _monthlyCadence = "Monthly"; - - public static ProviderSubscriptionDTO From( - IEnumerable providerPlans, - Subscription subscription) - { - var providerPlansDTO = providerPlans - .Select(providerPlan => - { - var plan = StaticStore.GetPlan(providerPlan.PlanType); - var cost = (providerPlan.SeatMinimum + providerPlan.PurchasedSeats) * plan.PasswordManager.SeatPrice; - var cadence = plan.IsAnnual ? _annualCadence : _monthlyCadence; - return new ProviderPlanDTO( - plan.Name, - providerPlan.SeatMinimum, - providerPlan.PurchasedSeats, - providerPlan.AssignedSeats, - cost, - cadence); - }); - - return new ProviderSubscriptionDTO( - subscription.Status, - subscription.CurrentPeriodEnd, - subscription.Customer?.Discount?.Coupon?.PercentOff, - providerPlansDTO); - } -} - -public record ProviderPlanDTO( - string PlanName, - int SeatMinimum, - int PurchasedSeats, - int AssignedSeats, - decimal Cost, - string Cadence); diff --git a/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs b/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs deleted file mode 100644 index 7bac8fdef4..0000000000 --- a/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Bit.Api.Billing.Models; - -public class UpdateProviderOrganizationRequestBody -{ - public int AssignedSeats { get; set; } -} diff --git a/src/Api/Controllers/CollectionsController.cs b/src/Api/Controllers/CollectionsController.cs index 7711e44220..3eeae17a50 100644 --- a/src/Api/Controllers/CollectionsController.cs +++ b/src/Api/Controllers/CollectionsController.cs @@ -2,6 +2,7 @@ using Bit.Api.Models.Response; using Bit.Api.Utilities; using Bit.Api.Vault.AuthorizationHandlers.Collections; +using Bit.Core; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -10,6 +11,7 @@ using Bit.Core.Models.Data; using Bit.Core.OrganizationFeatures.OrganizationCollections.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -320,6 +322,7 @@ public class CollectionsController : Controller } [HttpPost("bulk-access")] + [RequireFeature(FeatureFlagKeys.BulkCollectionAccess)] public async Task PostBulkCollectionAccess(Guid orgId, [FromBody] BulkCollectionAccessRequestModel model) { // Authorization logic assumes flexible collections is enabled diff --git a/src/Api/Models/Response/SubscriptionResponseModel.cs b/src/Api/Models/Response/SubscriptionResponseModel.cs index cca4f8ae72..7ba2b857eb 100644 --- a/src/Api/Models/Response/SubscriptionResponseModel.cs +++ b/src/Api/Models/Response/SubscriptionResponseModel.cs @@ -75,10 +75,6 @@ public class BillingSubscription { Items = sub.Items.Select(i => new BillingSubscriptionItem(i)); } - CollectionMethod = sub.CollectionMethod; - SuspensionDate = sub.SuspensionDate; - UnpaidPeriodEndDate = sub.UnpaidPeriodEndDate; - GracePeriod = sub.GracePeriod; } public DateTime? TrialStartDate { get; set; } @@ -90,10 +86,6 @@ public class BillingSubscription public string Status { get; set; } public bool Cancelled { get; set; } public IEnumerable Items { get; set; } = new List(); - public string CollectionMethod { get; set; } - public DateTime? SuspensionDate { get; set; } - public DateTime? UnpaidPeriodEndDate { get; set; } - public int? GracePeriod { get; set; } public class BillingSubscriptionItem { diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index 63b1a3c3cd..9f94325513 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -170,7 +170,8 @@ public class Startup services.AddDefaultServices(globalSettings); services.AddOrganizationSubscriptionServices(); services.AddCoreLocalizationServices(); - services.AddBillingOperations(); + services.AddBillingCommands(); + services.AddBillingQueries(); // Authorization Handlers services.AddAuthorizationHandlers(); diff --git a/src/Api/appsettings.json b/src/Api/appsettings.json index c04539a9fe..e49491857f 100644 --- a/src/Api/appsettings.json +++ b/src/Api/appsettings.json @@ -32,6 +32,10 @@ "send": { "connectionString": "SECRET" }, + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + }, "sentry": { "dsn": "SECRET" }, diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index 679dea15ce..e78ed31ff1 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -868,7 +868,7 @@ public class StripeController : Controller private bool UnpaidAutoChargeInvoiceForSubscriptionCycle(Invoice invoice) { return invoice.AmountDue > 0 && !invoice.Paid && invoice.CollectionMethod == "charge_automatically" && - invoice.BillingReason is "subscription_cycle" or "automatic_pending_invoice_item_invoice" && invoice.SubscriptionId != null; + invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null; } private async Task VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription) diff --git a/src/Billing/appsettings.json b/src/Billing/appsettings.json index 4985784573..93d103aa80 100644 --- a/src/Billing/appsettings.json +++ b/src/Billing/appsettings.json @@ -30,6 +30,10 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + }, "sentry": { "dsn": "SECRET" }, diff --git a/src/Core/AdminConsole/Entities/Provider/Provider.cs b/src/Core/AdminConsole/Entities/Provider/Provider.cs index e5b794e6b1..ee2b35ed90 100644 --- a/src/Core/AdminConsole/Entities/Provider/Provider.cs +++ b/src/Core/AdminConsole/Entities/Provider/Provider.cs @@ -6,7 +6,7 @@ using Bit.Core.Utilities; namespace Bit.Core.AdminConsole.Entities.Provider; -public class Provider : ITableObject, ISubscriber +public class Provider : ITableObject { public Guid Id { get; set; } /// @@ -34,26 +34,6 @@ public class Provider : ITableObject, ISubscriber public string GatewayCustomerId { get; set; } public string GatewaySubscriptionId { get; set; } - public string BillingEmailAddress() => BillingEmail?.ToLowerInvariant().Trim(); - - public string BillingName() => DisplayBusinessName(); - - public string SubscriberName() => DisplayName(); - - public string BraintreeCustomerIdPrefix() => "p"; - - public string BraintreeIdField() => "provider_id"; - - public string BraintreeCloudRegionField() => "region"; - - public bool IsOrganization() => false; - - public bool IsUser() => false; - - public string SubscriberType() => "Provider"; - - public bool IsExpired() => false; - public void SetNewId() { if (Id == default) diff --git a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs index 80a16e495a..39cc5a1d98 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs @@ -146,8 +146,7 @@ public class SelfHostedOrganizationDetails : Organization OwnersNotifiedOfAutoscaling = OwnersNotifiedOfAutoscaling, LimitCollectionCreationDeletion = LimitCollectionCreationDeletion, AllowAdminAccessToAllCollectionItems = AllowAdminAccessToAllCollectionItems, - FlexibleCollections = FlexibleCollections, - Status = Status + FlexibleCollections = FlexibleCollections }; } } diff --git a/src/Core/Billing/BillingException.cs b/src/Core/Billing/BillingException.cs deleted file mode 100644 index a6944b3ed6..0000000000 --- a/src/Core/Billing/BillingException.cs +++ /dev/null @@ -1,9 +0,0 @@ -namespace Bit.Core.Billing; - -public class BillingException( - string clientFriendlyMessage, - string internalMessage = null, - Exception innerException = null) : Exception(internalMessage, innerException) -{ - public string ClientFriendlyMessage { get; set; } = clientFriendlyMessage; -} diff --git a/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs b/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs deleted file mode 100644 index db21926bec..0000000000 --- a/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs +++ /dev/null @@ -1,12 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; - -namespace Bit.Core.Billing.Commands; - -public interface IAssignSeatsToClientOrganizationCommand -{ - Task AssignSeatsToClientOrganization( - Provider provider, - Organization organization, - int seats); -} diff --git a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs b/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs index 88708d3d2e..b23880e650 100644 --- a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs +++ b/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Models; using Bit.Core.Entities; +using Bit.Core.Exceptions; using Stripe; namespace Bit.Core.Billing.Commands; @@ -16,6 +17,7 @@ public interface ICancelSubscriptionCommand /// The or with the subscription to cancel. /// An DTO containing user-provided feedback on why they are cancelling the subscription. /// A flag indicating whether to cancel the subscription immediately or at the end of the subscription period. + /// Thrown when the provided subscription is already in an inactive state. Task CancelSubscription( Subscription subscription, OffboardingSurveyResponse offboardingSurveyResponse, diff --git a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs index e2be6f45eb..62bf0d0926 100644 --- a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs +++ b/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs @@ -4,12 +4,5 @@ namespace Bit.Core.Billing.Commands; public interface IRemovePaymentMethodCommand { - /// - /// Attempts to remove an Organization's saved payment method. If the Stripe representing the - /// contains a valid "btCustomerId" key in its property, - /// this command will attempt to remove the Braintree . Otherwise, it will attempt to remove the - /// Stripe . - /// - /// The organization to remove the saved payment method for. Task RemovePaymentMethod(Organization organization); } diff --git a/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs b/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs deleted file mode 100644 index be2c6be968..0000000000 --- a/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs +++ /dev/null @@ -1,174 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.Billing.Entities; -using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Queries; -using Bit.Core.Billing.Repositories; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Utilities; -using Microsoft.Extensions.Logging; -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Commands.Implementations; - -public class AssignSeatsToClientOrganizationCommand( - ILogger logger, - IOrganizationRepository organizationRepository, - IPaymentService paymentService, - IProviderBillingQueries providerBillingQueries, - IProviderPlanRepository providerPlanRepository) : IAssignSeatsToClientOrganizationCommand -{ - public async Task AssignSeatsToClientOrganization( - Provider provider, - Organization organization, - int seats) - { - ArgumentNullException.ThrowIfNull(provider); - ArgumentNullException.ThrowIfNull(organization); - - if (provider.Type == ProviderType.Reseller) - { - logger.LogError("Reseller-type provider ({ID}) cannot assign seats to client organizations", provider.Id); - - throw ContactSupport("Consolidated billing does not support reseller-type providers"); - } - - if (seats < 0) - { - throw new BillingException( - "You cannot assign negative seats to a client.", - "MSP cannot assign negative seats to a client organization"); - } - - if (seats == organization.Seats) - { - logger.LogWarning("Client organization ({ID}) already has {Seats} seats assigned", organization.Id, organization.Seats); - - return; - } - - var providerPlan = await GetProviderPlanAsync(provider, organization); - - var providerSeatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0); - - // How many seats the provider has assigned to all their client organizations that have the specified plan type. - var providerCurrentlyAssignedSeatTotal = await providerBillingQueries.GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType); - - // How many seats are being added to or subtracted from this client organization. - var seatDifference = seats - (organization.Seats ?? 0); - - // How many seats the provider will have assigned to all of their client organizations after the update. - var providerNewlyAssignedSeatTotal = providerCurrentlyAssignedSeatTotal + seatDifference; - - var update = CurryUpdateFunction( - provider, - providerPlan, - organization, - seats, - providerNewlyAssignedSeatTotal); - - /* - * Below the limit => Below the limit: - * No subscription update required. We can safely update the organization's seats. - */ - if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum && - providerNewlyAssignedSeatTotal <= providerSeatMinimum) - { - organization.Seats = seats; - - await organizationRepository.ReplaceAsync(organization); - - providerPlan.AllocatedSeats = providerNewlyAssignedSeatTotal; - - await providerPlanRepository.ReplaceAsync(providerPlan); - } - /* - * Below the limit => Above the limit: - * We have to scale the subscription up from the seat minimum to the newly assigned seat total. - */ - else if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum && - providerNewlyAssignedSeatTotal > providerSeatMinimum) - { - await update( - providerSeatMinimum, - providerNewlyAssignedSeatTotal); - } - /* - * Above the limit => Above the limit: - * We have to scale the subscription from the currently assigned seat total to the newly assigned seat total. - */ - else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum && - providerNewlyAssignedSeatTotal > providerSeatMinimum) - { - await update( - providerCurrentlyAssignedSeatTotal, - providerNewlyAssignedSeatTotal); - } - /* - * Above the limit => Below the limit: - * We have to scale the subscription down from the currently assigned seat total to the seat minimum. - */ - else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum && - providerNewlyAssignedSeatTotal <= providerSeatMinimum) - { - await update( - providerCurrentlyAssignedSeatTotal, - providerSeatMinimum); - } - } - - // ReSharper disable once SuggestBaseTypeForParameter - private async Task GetProviderPlanAsync(Provider provider, Organization organization) - { - if (!organization.PlanType.SupportsConsolidatedBilling()) - { - logger.LogError("Cannot assign seats to a client organization ({ID}) with a plan type that does not support consolidated billing: {PlanType}", organization.Id, organization.PlanType); - - throw ContactSupport(); - } - - var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - - var providerPlan = providerPlans.FirstOrDefault(providerPlan => providerPlan.PlanType == organization.PlanType); - - if (providerPlan != null && providerPlan.IsConfigured()) - { - return providerPlan; - } - - logger.LogError("Cannot assign seats to client organization ({ClientOrganizationID}) when provider's ({ProviderID}) matching plan is not configured", organization.Id, provider.Id); - - throw ContactSupport(); - } - - private Func CurryUpdateFunction( - Provider provider, - ProviderPlan providerPlan, - Organization organization, - int organizationNewlyAssignedSeats, - int providerNewlyAssignedSeats) => async (providerCurrentlySubscribedSeats, providerNewlySubscribedSeats) => - { - var plan = StaticStore.GetPlan(providerPlan.PlanType); - - await paymentService.AdjustSeats( - provider, - plan, - providerCurrentlySubscribedSeats, - providerNewlySubscribedSeats); - - organization.Seats = organizationNewlyAssignedSeats; - - await organizationRepository.ReplaceAsync(organization); - - var providerNewlyPurchasedSeats = providerNewlySubscribedSeats > providerPlan.SeatMinimum - ? providerNewlySubscribedSeats - providerPlan.SeatMinimum - : 0; - - providerPlan.PurchasedSeats = providerNewlyPurchasedSeats; - providerPlan.AllocatedSeats = providerNewlyAssignedSeats; - - await providerPlanRepository.ReplaceAsync(providerPlan); - }; -} diff --git a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs index be8479ea99..c5dbb6d927 100644 --- a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs +++ b/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs @@ -1,41 +1,55 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Enums; +using Bit.Core.Exceptions; using Bit.Core.Services; using Braintree; using Microsoft.Extensions.Logging; -using static Bit.Core.Billing.Utilities; - namespace Bit.Core.Billing.Commands.Implementations; -public class RemovePaymentMethodCommand( - IBraintreeGateway braintreeGateway, - ILogger logger, - IStripeAdapter stripeAdapter) - : IRemovePaymentMethodCommand +public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand { + private readonly IBraintreeGateway _braintreeGateway; + private readonly ILogger _logger; + private readonly IStripeAdapter _stripeAdapter; + + public RemovePaymentMethodCommand( + IBraintreeGateway braintreeGateway, + ILogger logger, + IStripeAdapter stripeAdapter) + { + _braintreeGateway = braintreeGateway; + _logger = logger; + _stripeAdapter = stripeAdapter; + } + public async Task RemovePaymentMethod(Organization organization) { - ArgumentNullException.ThrowIfNull(organization); + const string braintreeCustomerIdKey = "btCustomerId"; + + if (organization == null) + { + throw new ArgumentNullException(nameof(organization)); + } if (organization.Gateway is not GatewayType.Stripe || string.IsNullOrEmpty(organization.GatewayCustomerId)) { throw ContactSupport(); } - var stripeCustomer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, new Stripe.CustomerGetOptions + var stripeCustomer = await _stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, new Stripe.CustomerGetOptions { - Expand = ["invoice_settings.default_payment_method", "sources"] + Expand = new List { "invoice_settings.default_payment_method", "sources" } }); if (stripeCustomer == null) { - logger.LogError("Could not find Stripe customer ({ID}) when removing payment method", organization.GatewayCustomerId); + _logger.LogError("Could not find Stripe customer ({ID}) when removing payment method", organization.GatewayCustomerId); throw ContactSupport(); } - if (stripeCustomer.Metadata?.TryGetValue(BraintreeCustomerIdKey, out var braintreeCustomerId) ?? false) + if (stripeCustomer.Metadata?.TryGetValue(braintreeCustomerIdKey, out var braintreeCustomerId) ?? false) { await RemoveBraintreePaymentMethodAsync(braintreeCustomerId); } @@ -47,11 +61,11 @@ public class RemovePaymentMethodCommand( private async Task RemoveBraintreePaymentMethodAsync(string braintreeCustomerId) { - var customer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); + var customer = await _braintreeGateway.Customer.FindAsync(braintreeCustomerId); if (customer == null) { - logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); + _logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); throw ContactSupport(); } @@ -60,27 +74,27 @@ public class RemovePaymentMethodCommand( { var existingDefaultPaymentMethod = customer.DefaultPaymentMethod; - var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync( + var updateCustomerResult = await _braintreeGateway.Customer.UpdateAsync( braintreeCustomerId, new CustomerRequest { DefaultPaymentMethodToken = null }); if (!updateCustomerResult.IsSuccess()) { - logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", + _logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", braintreeCustomerId, updateCustomerResult.Message); throw ContactSupport(); } - var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); + var deletePaymentMethodResult = await _braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); if (!deletePaymentMethodResult.IsSuccess()) { - await braintreeGateway.Customer.UpdateAsync( + await _braintreeGateway.Customer.UpdateAsync( braintreeCustomerId, new CustomerRequest { DefaultPaymentMethodToken = existingDefaultPaymentMethod.Token }); - logger.LogError( + _logger.LogError( "Failed to delete Braintree payment method for Customer ({ID}), re-linked payment method. Message: {Message}", braintreeCustomerId, deletePaymentMethodResult.Message); @@ -89,7 +103,7 @@ public class RemovePaymentMethodCommand( } else { - logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId); + _logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId); } } @@ -102,23 +116,25 @@ public class RemovePaymentMethodCommand( switch (source) { case Stripe.BankAccount: - await stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); + await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); break; case Stripe.Card: - await stripeAdapter.CardDeleteAsync(customer.Id, source.Id); + await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id); break; } } } - var paymentMethods = stripeAdapter.PaymentMethodListAutoPagingAsync(new Stripe.PaymentMethodListOptions + var paymentMethods = _stripeAdapter.PaymentMethodListAutoPagingAsync(new Stripe.PaymentMethodListOptions { Customer = customer.Id }); await foreach (var paymentMethod in paymentMethods) { - await stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id, new Stripe.PaymentMethodDetachOptions()); + await _stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id, new Stripe.PaymentMethodDetachOptions()); } } + + private static GatewayException ContactSupport() => new("Could not remove your payment method. Please contact support for assistance."); } diff --git a/src/Core/Billing/Entities/ProviderPlan.cs b/src/Core/Billing/Entities/ProviderPlan.cs index f4965570d9..325dbbb156 100644 --- a/src/Core/Billing/Entities/ProviderPlan.cs +++ b/src/Core/Billing/Entities/ProviderPlan.cs @@ -20,6 +20,4 @@ public class ProviderPlan : ITableObject Id = CoreHelpers.GenerateComb(); } } - - public bool IsConfigured() => SeatMinimum.HasValue && PurchasedSeats.HasValue && AllocatedSeats.HasValue; } diff --git a/src/Core/Billing/Extensions/BillingExtensions.cs b/src/Core/Billing/Extensions/BillingExtensions.cs deleted file mode 100644 index c7abeb81e2..0000000000 --- a/src/Core/Billing/Extensions/BillingExtensions.cs +++ /dev/null @@ -1,9 +0,0 @@ -using Bit.Core.Enums; - -namespace Bit.Core.Billing.Extensions; - -public static class BillingExtensions -{ - public static bool SupportsConsolidatedBilling(this PlanType planType) - => planType is PlanType.TeamsMonthly or PlanType.EnterpriseMonthly; -} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 8e28b23397..113fa4d5b7 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -9,15 +9,14 @@ using Microsoft.Extensions.DependencyInjection; public static class ServiceCollectionExtensions { - public static void AddBillingOperations(this IServiceCollection services) + public static void AddBillingCommands(this IServiceCollection services) { - // Queries - services.AddTransient(); - services.AddTransient(); + services.AddSingleton(); + services.AddSingleton(); + } - // Commands - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); + public static void AddBillingQueries(this IServiceCollection services) + { + services.AddSingleton(); } } diff --git a/src/Core/Billing/Models/ConfiguredProviderPlan.cs b/src/Core/Billing/Models/ConfiguredProviderPlan.cs deleted file mode 100644 index d6bc2b7522..0000000000 --- a/src/Core/Billing/Models/ConfiguredProviderPlan.cs +++ /dev/null @@ -1,24 +0,0 @@ -using Bit.Core.Billing.Entities; -using Bit.Core.Enums; - -namespace Bit.Core.Billing.Models; - -public record ConfiguredProviderPlan( - Guid Id, - Guid ProviderId, - PlanType PlanType, - int SeatMinimum, - int PurchasedSeats, - int AssignedSeats) -{ - public static ConfiguredProviderPlan From(ProviderPlan providerPlan) => - providerPlan.IsConfigured() - ? new ConfiguredProviderPlan( - providerPlan.Id, - providerPlan.ProviderId, - providerPlan.PlanType, - providerPlan.SeatMinimum.GetValueOrDefault(0), - providerPlan.PurchasedSeats.GetValueOrDefault(0), - providerPlan.AllocatedSeats.GetValueOrDefault(0)) - : null; -} diff --git a/src/Core/Billing/Models/ProviderSubscriptionData.cs b/src/Core/Billing/Models/ProviderSubscriptionData.cs deleted file mode 100644 index 27da6cd226..0000000000 --- a/src/Core/Billing/Models/ProviderSubscriptionData.cs +++ /dev/null @@ -1,7 +0,0 @@ -using Stripe; - -namespace Bit.Core.Billing.Models; - -public record ProviderSubscriptionData( - List ProviderPlans, - Subscription Subscription); diff --git a/src/Core/Billing/Queries/IGetSubscriptionQuery.cs b/src/Core/Billing/Queries/IGetSubscriptionQuery.cs new file mode 100644 index 0000000000..9ba2a85ed5 --- /dev/null +++ b/src/Core/Billing/Queries/IGetSubscriptionQuery.cs @@ -0,0 +1,18 @@ +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Stripe; + +namespace Bit.Core.Billing.Queries; + +public interface IGetSubscriptionQuery +{ + /// + /// Retrieves a Stripe using the 's property. + /// + /// The organization or user to retrieve the subscription for. + /// A Stripe . + /// Thrown when the is . + /// Thrown when the subscriber's is or empty. + /// Thrown when the returned from Stripe's API is null. + Task GetSubscription(ISubscriber subscriber); +} diff --git a/src/Core/Billing/Queries/IProviderBillingQueries.cs b/src/Core/Billing/Queries/IProviderBillingQueries.cs deleted file mode 100644 index e4b7d0f14d..0000000000 --- a/src/Core/Billing/Queries/IProviderBillingQueries.cs +++ /dev/null @@ -1,27 +0,0 @@ -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.Billing.Models; -using Bit.Core.Enums; - -namespace Bit.Core.Billing.Queries; - -public interface IProviderBillingQueries -{ - /// - /// Retrieves the number of seats an MSP has assigned to its client organizations with a specified . - /// - /// The ID of the MSP to retrieve the assigned seat total for. - /// The type of plan to retrieve the assigned seat total for. - /// An representing the number of seats the provider has assigned to its client organizations with the specified . - /// Thrown when the provider represented by the is . - /// Thrown when the provider represented by the has . - Task GetAssignedSeatTotalForPlanOrThrow(Guid providerId, PlanType planType); - - /// - /// Retrieves a provider's billing subscription data. - /// - /// The ID of the provider to retrieve subscription data for. - /// A object containing the provider's Stripe and their s. - /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetSubscriptionData(Guid providerId); -} diff --git a/src/Core/Billing/Queries/ISubscriberQueries.cs b/src/Core/Billing/Queries/ISubscriberQueries.cs deleted file mode 100644 index ea6c0d985e..0000000000 --- a/src/Core/Billing/Queries/ISubscriberQueries.cs +++ /dev/null @@ -1,30 +0,0 @@ -using Bit.Core.Entities; -using Bit.Core.Exceptions; -using Stripe; - -namespace Bit.Core.Billing.Queries; - -public interface ISubscriberQueries -{ - /// - /// Retrieves a Stripe using the 's property. - /// - /// The organization, provider or user to retrieve the subscription for. - /// Optional parameters that can be passed to Stripe to expand or modify the . - /// A Stripe . - /// Thrown when the is . - /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetSubscription( - ISubscriber subscriber, - SubscriptionGetOptions subscriptionGetOptions = null); - - /// - /// Retrieves a Stripe using the 's property. - /// - /// The organization or user to retrieve the subscription for. - /// A Stripe . - /// Thrown when the is . - /// Thrown when the subscriber's is or empty. - /// Thrown when the returned from Stripe's API is null. - Task GetSubscriptionOrThrow(ISubscriber subscriber); -} diff --git a/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs b/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs new file mode 100644 index 0000000000..c3b0a29552 --- /dev/null +++ b/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs @@ -0,0 +1,36 @@ +using Bit.Core.Entities; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using Stripe; + +using static Bit.Core.Billing.Utilities; + +namespace Bit.Core.Billing.Queries.Implementations; + +public class GetSubscriptionQuery( + ILogger logger, + IStripeAdapter stripeAdapter) : IGetSubscriptionQuery +{ + public async Task GetSubscription(ISubscriber subscriber) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); + + throw ContactSupport(); + } + + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + + if (subscription != null) + { + return subscription; + } + + logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); + + throw ContactSupport(); + } +} diff --git a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs deleted file mode 100644 index f8bff9d3fd..0000000000 --- a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs +++ /dev/null @@ -1,92 +0,0 @@ -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Repositories; -using Bit.Core.Enums; -using Bit.Core.Utilities; -using Microsoft.Extensions.Logging; -using Stripe; -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Queries.Implementations; - -public class ProviderBillingQueries( - ILogger logger, - IProviderOrganizationRepository providerOrganizationRepository, - IProviderPlanRepository providerPlanRepository, - IProviderRepository providerRepository, - ISubscriberQueries subscriberQueries) : IProviderBillingQueries -{ - public async Task GetAssignedSeatTotalForPlanOrThrow( - Guid providerId, - PlanType planType) - { - var provider = await providerRepository.GetByIdAsync(providerId); - - if (provider == null) - { - logger.LogError( - "Could not find provider ({ID}) when retrieving assigned seat total", - providerId); - - throw ContactSupport(); - } - - if (provider.Type == ProviderType.Reseller) - { - logger.LogError("Assigned seats cannot be retrieved for reseller-type provider ({ID})", providerId); - - throw ContactSupport("Consolidated billing does not support reseller-type providers"); - } - - var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); - - var plan = StaticStore.GetPlan(planType); - - return providerOrganizations - .Where(providerOrganization => providerOrganization.Plan == plan.Name) - .Sum(providerOrganization => providerOrganization.Seats ?? 0); - } - - public async Task GetSubscriptionData(Guid providerId) - { - var provider = await providerRepository.GetByIdAsync(providerId); - - if (provider == null) - { - logger.LogError( - "Could not find provider ({ID}) when retrieving subscription data.", - providerId); - - return null; - } - - if (provider.Type == ProviderType.Reseller) - { - logger.LogError("Subscription data cannot be retrieved for reseller-type provider ({ID})", providerId); - - throw ContactSupport("Consolidated billing does not support reseller-type providers"); - } - - var subscription = await subscriberQueries.GetSubscription(provider, new SubscriptionGetOptions - { - Expand = ["customer"] - }); - - if (subscription == null) - { - return null; - } - - var providerPlans = await providerPlanRepository.GetByProviderId(providerId); - - var configuredProviderPlans = providerPlans - .Where(providerPlan => providerPlan.IsConfigured()) - .Select(ConfiguredProviderPlan.From) - .ToList(); - - return new ProviderSubscriptionData( - configuredProviderPlans, - subscription); - } -} diff --git a/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs deleted file mode 100644 index a160a87595..0000000000 --- a/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs +++ /dev/null @@ -1,61 +0,0 @@ -using Bit.Core.Entities; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; - -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Queries.Implementations; - -public class SubscriberQueries( - ILogger logger, - IStripeAdapter stripeAdapter) : ISubscriberQueries -{ - public async Task GetSubscription( - ISubscriber subscriber, - SubscriptionGetOptions subscriptionGetOptions = null) - { - ArgumentNullException.ThrowIfNull(subscriber); - - if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) - { - logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); - - return null; - } - - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); - - if (subscription != null) - { - return subscription; - } - - logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); - - return null; - } - - public async Task GetSubscriptionOrThrow(ISubscriber subscriber) - { - ArgumentNullException.ThrowIfNull(subscriber); - - if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) - { - logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id); - - throw ContactSupport(); - } - - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - - if (subscription != null) - { - return subscription; - } - - logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId); - - throw ContactSupport(); - } -} diff --git a/src/Core/Billing/Repositories/IProviderPlanRepository.cs b/src/Core/Billing/Repositories/IProviderPlanRepository.cs index eccbad82bb..ccfc6ee683 100644 --- a/src/Core/Billing/Repositories/IProviderPlanRepository.cs +++ b/src/Core/Billing/Repositories/IProviderPlanRepository.cs @@ -5,5 +5,5 @@ namespace Bit.Core.Billing.Repositories; public interface IProviderPlanRepository : IRepository { - Task> GetByProviderId(Guid providerId); + Task GetByProviderId(Guid providerId); } diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs index 2b06f1ea6c..54ace07a70 100644 --- a/src/Core/Billing/Utilities.cs +++ b/src/Core/Billing/Utilities.cs @@ -1,11 +1,8 @@ -namespace Bit.Core.Billing; +using Bit.Core.Exceptions; + +namespace Bit.Core.Billing; public static class Utilities { - public const string BraintreeCustomerIdKey = "btCustomerId"; - - public static BillingException ContactSupport( - string internalMessage = null, - Exception innerException = null) => new("Something went wrong with your request. Please contact support.", - internalMessage, innerException); + public static GatewayException ContactSupport() => new("Something went wrong with your request. Please contact support."); } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 6edca0c505..e7685891ad 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -114,6 +114,7 @@ public static class FeatureFlagKeys /// public const string FlexibleCollections = "flexible-collections-disabled-do-not-use"; public const string FlexibleCollectionsV1 = "flexible-collections-v-1"; // v-1 is intentional + public const string BulkCollectionAccess = "bulk-collection-access"; public const string ItemShare = "item-share"; public const string KeyRotationImprovements = "key-rotation-improvements"; public const string DuoRedirect = "duo-redirect"; @@ -130,8 +131,6 @@ public static class FeatureFlagKeys public const string PM5864DollarThreshold = "PM-5864-dollar-threshold"; public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email"; public const string ShowPaymentMethodWarningBanners = "show-payment-method-warning-banners"; - public const string EnableConsolidatedBilling = "enable-consolidated-billing"; - public const string AC1795_UpdatedSubscriptionStatusSection = "AC-1795_updated-subscription-status-section"; public static List GetAllKeys() { diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 3e77b5d105..ff3c632b5b 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -21,9 +21,8 @@ - - - + + @@ -36,8 +35,9 @@ + - + @@ -50,10 +50,10 @@ - + - + diff --git a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs index aa1c92dc2e..a1146cd2a0 100644 --- a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs +++ b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; using Bit.Core.Exceptions; using Stripe; @@ -278,6 +279,25 @@ public class CompleteSubscriptionUpdate : SubscriptionUpdate }; } + private static SubscriptionItem FindSubscriptionItem(Subscription subscription, string planId) + { + if (string.IsNullOrEmpty(planId)) + { + return null; + } + + var data = subscription.Items.Data; + + var subscriptionItem = data.FirstOrDefault(item => item.Plan?.Id == planId) ?? data.FirstOrDefault(item => item.Price?.Id == planId); + + return subscriptionItem; + } + + private static string GetPasswordManagerPlanId(StaticStore.Plan plan) + => IsNonSeatBasedPlan(plan) + ? plan.PasswordManager.StripePlanId + : plan.PasswordManager.StripeSeatPlanId; + private static SubscriptionData GetSubscriptionDataFor(Organization organization) { var plan = Utilities.StaticStore.GetPlan(organization.PlanType); @@ -300,4 +320,10 @@ public class CompleteSubscriptionUpdate : SubscriptionUpdate 0 }; } + + private static bool IsNonSeatBasedPlan(StaticStore.Plan plan) + => plan.Type is + >= PlanType.FamiliesAnnually2019 and <= PlanType.EnterpriseAnnually2019 + or PlanType.FamiliesAnnually + or PlanType.TeamsStarter; } diff --git a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs b/src/Core/Models/Business/ProviderSubscriptionUpdate.cs deleted file mode 100644 index 8b29bebce5..0000000000 --- a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs +++ /dev/null @@ -1,61 +0,0 @@ -using Bit.Core.Billing.Extensions; -using Bit.Core.Enums; -using Stripe; - -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Models.Business; - -public class ProviderSubscriptionUpdate : SubscriptionUpdate -{ - private readonly string _planId; - private readonly int _previouslyPurchasedSeats; - private readonly int _newlyPurchasedSeats; - - protected override List PlanIds => [_planId]; - - public ProviderSubscriptionUpdate( - PlanType planType, - int previouslyPurchasedSeats, - int newlyPurchasedSeats) - { - if (!planType.SupportsConsolidatedBilling()) - { - throw ContactSupport($"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing"); - } - - _planId = GetPasswordManagerPlanId(Utilities.StaticStore.GetPlan(planType)); - _previouslyPurchasedSeats = previouslyPurchasedSeats; - _newlyPurchasedSeats = newlyPurchasedSeats; - } - - public override List RevertItemsOptions(Subscription subscription) - { - var subscriptionItem = FindSubscriptionItem(subscription, _planId); - - return - [ - new SubscriptionItemOptions - { - Id = subscriptionItem.Id, - Price = _planId, - Quantity = _previouslyPurchasedSeats - } - ]; - } - - public override List UpgradeItemsOptions(Subscription subscription) - { - var subscriptionItem = FindSubscriptionItem(subscription, _planId); - - return - [ - new SubscriptionItemOptions - { - Id = subscriptionItem.Id, - Price = _planId, - Quantity = _newlyPurchasedSeats - } - ]; - } -} diff --git a/src/Core/Models/Business/SeatSubscriptionUpdate.cs b/src/Core/Models/Business/SeatSubscriptionUpdate.cs index db5104ddd2..c5ea1a7474 100644 --- a/src/Core/Models/Business/SeatSubscriptionUpdate.cs +++ b/src/Core/Models/Business/SeatSubscriptionUpdate.cs @@ -18,7 +18,7 @@ public class SeatSubscriptionUpdate : SubscriptionUpdate public override List UpgradeItemsOptions(Subscription subscription) { - var item = FindSubscriptionItem(subscription, PlanIds.Single()); + var item = SubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions @@ -34,7 +34,7 @@ public class SeatSubscriptionUpdate : SubscriptionUpdate public override List RevertItemsOptions(Subscription subscription) { - var item = FindSubscriptionItem(subscription, PlanIds.Single()); + var item = SubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions diff --git a/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs b/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs index c3e3e09992..c93212eac8 100644 --- a/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs +++ b/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs @@ -19,7 +19,7 @@ public class ServiceAccountSubscriptionUpdate : SubscriptionUpdate public override List UpgradeItemsOptions(Subscription subscription) { - var item = FindSubscriptionItem(subscription, PlanIds.Single()); + var item = SubscriptionItem(subscription, PlanIds.Single()); _prevServiceAccounts = item?.Quantity ?? 0; return new() { @@ -35,7 +35,7 @@ public class ServiceAccountSubscriptionUpdate : SubscriptionUpdate public override List RevertItemsOptions(Subscription subscription) { - var item = FindSubscriptionItem(subscription, PlanIds.Single()); + var item = SubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions diff --git a/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs b/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs index b8201b9775..ff6bb55011 100644 --- a/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs +++ b/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs @@ -19,7 +19,7 @@ public class SmSeatSubscriptionUpdate : SubscriptionUpdate public override List UpgradeItemsOptions(Subscription subscription) { - var item = FindSubscriptionItem(subscription, PlanIds.Single()); + var item = SubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions @@ -35,7 +35,7 @@ public class SmSeatSubscriptionUpdate : SubscriptionUpdate public override List RevertItemsOptions(Subscription subscription) { - var item = FindSubscriptionItem(subscription, PlanIds.Single()); + var item = SubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions diff --git a/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs b/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs index 59a745297b..88af72f199 100644 --- a/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs +++ b/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs @@ -74,10 +74,10 @@ public class SponsorOrganizationSubscriptionUpdate : SubscriptionUpdate private string AddStripePlanId => _applySponsorship ? _sponsoredPlanStripeId : _existingPlanStripeId; private Stripe.SubscriptionItem RemoveStripeItem(Subscription subscription) => _applySponsorship ? - FindSubscriptionItem(subscription, _existingPlanStripeId) : - FindSubscriptionItem(subscription, _sponsoredPlanStripeId); + SubscriptionItem(subscription, _existingPlanStripeId) : + SubscriptionItem(subscription, _sponsoredPlanStripeId); private Stripe.SubscriptionItem AddStripeItem(Subscription subscription) => _applySponsorship ? - FindSubscriptionItem(subscription, _sponsoredPlanStripeId) : - FindSubscriptionItem(subscription, _existingPlanStripeId); + SubscriptionItem(subscription, _sponsoredPlanStripeId) : + SubscriptionItem(subscription, _existingPlanStripeId); } diff --git a/src/Core/Models/Business/StorageSubscriptionUpdate.cs b/src/Core/Models/Business/StorageSubscriptionUpdate.cs index b0f4a83d3e..30ab2428e2 100644 --- a/src/Core/Models/Business/StorageSubscriptionUpdate.cs +++ b/src/Core/Models/Business/StorageSubscriptionUpdate.cs @@ -17,7 +17,7 @@ public class StorageSubscriptionUpdate : SubscriptionUpdate public override List UpgradeItemsOptions(Subscription subscription) { - var item = FindSubscriptionItem(subscription, PlanIds.Single()); + var item = SubscriptionItem(subscription, PlanIds.Single()); _prevStorage = item?.Quantity ?? 0; return new() { @@ -38,7 +38,7 @@ public class StorageSubscriptionUpdate : SubscriptionUpdate throw new Exception("Unknown previous value, must first call UpgradeItemsOptions"); } - var item = FindSubscriptionItem(subscription, PlanIds.Single()); + var item = SubscriptionItem(subscription, PlanIds.Single()); return new() { new SubscriptionItemOptions diff --git a/src/Core/Models/Business/SubscriptionInfo.cs b/src/Core/Models/Business/SubscriptionInfo.cs index 7bb5bddbc8..23f8f95278 100644 --- a/src/Core/Models/Business/SubscriptionInfo.cs +++ b/src/Core/Models/Business/SubscriptionInfo.cs @@ -43,9 +43,6 @@ public class SubscriptionInfo Items = sub.Items.Data.Select(i => new BillingSubscriptionItem(i)); } CollectionMethod = sub.CollectionMethod; - GracePeriod = sub.CollectionMethod == "charge_automatically" - ? 14 - : 30; } public DateTime? TrialStartDate { get; set; } @@ -59,9 +56,6 @@ public class SubscriptionInfo public bool Cancelled { get; set; } public IEnumerable Items { get; set; } = new List(); public string CollectionMethod { get; set; } - public DateTime? SuspensionDate { get; set; } - public DateTime? UnpaidPeriodEndDate { get; set; } - public int GracePeriod { get; set; } public class BillingSubscriptionItem { diff --git a/src/Core/Models/Business/SubscriptionUpdate.cs b/src/Core/Models/Business/SubscriptionUpdate.cs index bba9d384d2..70106a10ea 100644 --- a/src/Core/Models/Business/SubscriptionUpdate.cs +++ b/src/Core/Models/Business/SubscriptionUpdate.cs @@ -1,5 +1,4 @@ -using Bit.Core.Enums; -using Stripe; +using Stripe; namespace Bit.Core.Models.Business; @@ -16,7 +15,7 @@ public abstract class SubscriptionUpdate foreach (var upgradeItemOptions in upgradeItemsOptions) { var upgradeQuantity = upgradeItemOptions.Quantity ?? 0; - var existingQuantity = FindSubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0; + var existingQuantity = SubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0; if (upgradeQuantity != existingQuantity) { return true; @@ -25,28 +24,6 @@ public abstract class SubscriptionUpdate return false; } - protected static SubscriptionItem FindSubscriptionItem(Subscription subscription, string planId) - { - if (string.IsNullOrEmpty(planId)) - { - return null; - } - - var data = subscription.Items.Data; - - var subscriptionItem = data.FirstOrDefault(item => item.Plan?.Id == planId) ?? data.FirstOrDefault(item => item.Price?.Id == planId); - - return subscriptionItem; - } - - protected static string GetPasswordManagerPlanId(StaticStore.Plan plan) - => IsNonSeatBasedPlan(plan) - ? plan.PasswordManager.StripePlanId - : plan.PasswordManager.StripeSeatPlanId; - - protected static bool IsNonSeatBasedPlan(StaticStore.Plan plan) - => plan.Type is - >= PlanType.FamiliesAnnually2019 and <= PlanType.EnterpriseAnnually2019 - or PlanType.FamiliesAnnually - or PlanType.TeamsStarter; + protected static SubscriptionItem SubscriptionItem(Subscription subscription, string planId) => + planId == null ? null : subscription.Items?.Data?.FirstOrDefault(i => i.Plan.Id == planId); } diff --git a/src/Core/Models/Data/DictionaryEntity.cs b/src/Core/Models/Data/DictionaryEntity.cs new file mode 100644 index 0000000000..72e6c871c7 --- /dev/null +++ b/src/Core/Models/Data/DictionaryEntity.cs @@ -0,0 +1,134 @@ +using System.Collections; +using Microsoft.Azure.Cosmos.Table; + +namespace Bit.Core.Models.Data; + +public class DictionaryEntity : TableEntity, IDictionary +{ + private IDictionary _properties = new Dictionary(); + + public ICollection Values => _properties.Values; + + public EntityProperty this[string key] + { + get => _properties[key]; + set => _properties[key] = value; + } + + public int Count => _properties.Count; + + public bool IsReadOnly => _properties.IsReadOnly; + + public ICollection Keys => _properties.Keys; + + public override void ReadEntity(IDictionary properties, + OperationContext operationContext) + { + _properties = properties; + } + + public override IDictionary WriteEntity(OperationContext operationContext) + { + return _properties; + } + + public void Add(string key, EntityProperty value) + { + _properties.Add(key, value); + } + + public void Add(string key, bool value) + { + _properties.Add(key, new EntityProperty(value)); + } + + public void Add(string key, byte[] value) + { + _properties.Add(key, new EntityProperty(value)); + } + + public void Add(string key, DateTime? value) + { + _properties.Add(key, new EntityProperty(value)); + } + + public void Add(string key, DateTimeOffset? value) + { + _properties.Add(key, new EntityProperty(value)); + } + + public void Add(string key, double value) + { + _properties.Add(key, new EntityProperty(value)); + } + + public void Add(string key, Guid value) + { + _properties.Add(key, new EntityProperty(value)); + } + + public void Add(string key, int value) + { + _properties.Add(key, new EntityProperty(value)); + } + + public void Add(string key, long value) + { + _properties.Add(key, new EntityProperty(value)); + } + + public void Add(string key, string value) + { + _properties.Add(key, new EntityProperty(value)); + } + + public void Add(KeyValuePair item) + { + _properties.Add(item); + } + + public bool ContainsKey(string key) + { + return _properties.ContainsKey(key); + } + + public bool Remove(string key) + { + return _properties.Remove(key); + } + + public bool TryGetValue(string key, out EntityProperty value) + { + return _properties.TryGetValue(key, out value); + } + + public void Clear() + { + _properties.Clear(); + } + + public bool Contains(KeyValuePair item) + { + return _properties.Contains(item); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + _properties.CopyTo(array, arrayIndex); + } + + public bool Remove(KeyValuePair item) + { + return _properties.Remove(item); + } + + public IEnumerator> GetEnumerator() + { + return _properties.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _properties.GetEnumerator(); + } +} diff --git a/src/Core/Models/Data/EventTableEntity.cs b/src/Core/Models/Data/EventTableEntity.cs index 69365f4127..df4a85acaf 100644 --- a/src/Core/Models/Data/EventTableEntity.cs +++ b/src/Core/Models/Data/EventTableEntity.cs @@ -1,73 +1,10 @@ -using Azure; -using Azure.Data.Tables; -using Bit.Core.Enums; +using Bit.Core.Enums; using Bit.Core.Utilities; +using Microsoft.Azure.Cosmos.Table; namespace Bit.Core.Models.Data; -// used solely for interaction with Azure Table Storage -public class AzureEvent : ITableEntity -{ - public string PartitionKey { get; set; } - public string RowKey { get; set; } - public DateTimeOffset? Timestamp { get; set; } - public ETag ETag { get; set; } - - public DateTime Date { get; set; } - public int Type { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? InstallationId { get; set; } - public Guid? ProviderId { get; set; } - public Guid? CipherId { get; set; } - public Guid? CollectionId { get; set; } - public Guid? PolicyId { get; set; } - public Guid? GroupId { get; set; } - public Guid? OrganizationUserId { get; set; } - public Guid? ProviderUserId { get; set; } - public Guid? ProviderOrganizationId { get; set; } - public int? DeviceType { get; set; } - public string IpAddress { get; set; } - public Guid? ActingUserId { get; set; } - public int? SystemUser { get; set; } - public string DomainName { get; set; } - public Guid? SecretId { get; set; } - public Guid? ServiceAccountId { get; set; } - - public EventTableEntity ToEventTableEntity() - { - return new EventTableEntity - { - PartitionKey = PartitionKey, - RowKey = RowKey, - Timestamp = Timestamp, - ETag = ETag, - - Date = Date, - Type = (EventType)Type, - UserId = UserId, - OrganizationId = OrganizationId, - InstallationId = InstallationId, - ProviderId = ProviderId, - CipherId = CipherId, - CollectionId = CollectionId, - PolicyId = PolicyId, - GroupId = GroupId, - OrganizationUserId = OrganizationUserId, - ProviderUserId = ProviderUserId, - ProviderOrganizationId = ProviderOrganizationId, - DeviceType = DeviceType.HasValue ? (DeviceType)DeviceType.Value : null, - IpAddress = IpAddress, - ActingUserId = ActingUserId, - SystemUser = SystemUser.HasValue ? (EventSystemUser)SystemUser.Value : null, - DomainName = DomainName, - SecretId = SecretId, - ServiceAccountId = ServiceAccountId - }; - } -} - -public class EventTableEntity : IEvent +public class EventTableEntity : TableEntity, IEvent { public EventTableEntity() { } @@ -95,11 +32,6 @@ public class EventTableEntity : IEvent ServiceAccountId = e.ServiceAccountId; } - public string PartitionKey { get; set; } - public string RowKey { get; set; } - public DateTimeOffset? Timestamp { get; set; } - public ETag ETag { get; set; } - public DateTime Date { get; set; } public EventType Type { get; set; } public Guid? UserId { get; set; } @@ -121,36 +53,65 @@ public class EventTableEntity : IEvent public Guid? SecretId { get; set; } public Guid? ServiceAccountId { get; set; } - public AzureEvent ToAzureEvent() + public override IDictionary WriteEntity(OperationContext operationContext) { - return new AzureEvent - { - PartitionKey = PartitionKey, - RowKey = RowKey, - Timestamp = Timestamp, - ETag = ETag, + var result = base.WriteEntity(operationContext); - Date = Date, - Type = (int)Type, - UserId = UserId, - OrganizationId = OrganizationId, - InstallationId = InstallationId, - ProviderId = ProviderId, - CipherId = CipherId, - CollectionId = CollectionId, - PolicyId = PolicyId, - GroupId = GroupId, - OrganizationUserId = OrganizationUserId, - ProviderUserId = ProviderUserId, - ProviderOrganizationId = ProviderOrganizationId, - DeviceType = DeviceType.HasValue ? (int)DeviceType.Value : null, - IpAddress = IpAddress, - ActingUserId = ActingUserId, - SystemUser = SystemUser.HasValue ? (int)SystemUser.Value : null, - DomainName = DomainName, - SecretId = SecretId, - ServiceAccountId = ServiceAccountId - }; + var typeName = nameof(Type); + if (result.ContainsKey(typeName)) + { + result[typeName] = new EntityProperty((int)Type); + } + else + { + result.Add(typeName, new EntityProperty((int)Type)); + } + + var deviceTypeName = nameof(DeviceType); + if (result.ContainsKey(deviceTypeName)) + { + result[deviceTypeName] = new EntityProperty((int?)DeviceType); + } + else + { + result.Add(deviceTypeName, new EntityProperty((int?)DeviceType)); + } + + var systemUserTypeName = nameof(SystemUser); + if (result.ContainsKey(systemUserTypeName)) + { + result[systemUserTypeName] = new EntityProperty((int?)SystemUser); + } + else + { + result.Add(systemUserTypeName, new EntityProperty((int?)SystemUser)); + } + + return result; + } + + public override void ReadEntity(IDictionary properties, + OperationContext operationContext) + { + base.ReadEntity(properties, operationContext); + + var typeName = nameof(Type); + if (properties.ContainsKey(typeName) && properties[typeName].Int32Value.HasValue) + { + Type = (EventType)properties[typeName].Int32Value.Value; + } + + var deviceTypeName = nameof(DeviceType); + if (properties.ContainsKey(deviceTypeName) && properties[deviceTypeName].Int32Value.HasValue) + { + DeviceType = (DeviceType)properties[deviceTypeName].Int32Value.Value; + } + + var systemUserTypeName = nameof(SystemUser); + if (properties.ContainsKey(systemUserTypeName) && properties[systemUserTypeName].Int32Value.HasValue) + { + SystemUser = (EventSystemUser)properties[systemUserTypeName].Int32Value.Value; + } } public static List IndexEvent(EventMessage e) diff --git a/src/Core/Models/Data/InstallationDeviceEntity.cs b/src/Core/Models/Data/InstallationDeviceEntity.cs index 3186efc661..cb7bf00873 100644 --- a/src/Core/Models/Data/InstallationDeviceEntity.cs +++ b/src/Core/Models/Data/InstallationDeviceEntity.cs @@ -1,9 +1,8 @@ -using Azure; -using Azure.Data.Tables; +using Microsoft.Azure.Cosmos.Table; namespace Bit.Core.Models.Data; -public class InstallationDeviceEntity : ITableEntity +public class InstallationDeviceEntity : TableEntity { public InstallationDeviceEntity() { } @@ -28,11 +27,6 @@ public class InstallationDeviceEntity : ITableEntity RowKey = parts[1]; } - public string PartitionKey { get; set; } - public string RowKey { get; set; } - public DateTimeOffset? Timestamp { get; set; } - public ETag ETag { get; set; } - public static bool IsInstallationDeviceId(string deviceId) { return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_'; diff --git a/src/Core/Repositories/TableStorage/EventRepository.cs b/src/Core/Repositories/TableStorage/EventRepository.cs index 7c5cb97dba..7044850033 100644 --- a/src/Core/Repositories/TableStorage/EventRepository.cs +++ b/src/Core/Repositories/TableStorage/EventRepository.cs @@ -1,14 +1,14 @@ -using Azure.Data.Tables; -using Bit.Core.Models.Data; +using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.Core.Vault.Entities; +using Microsoft.Azure.Cosmos.Table; namespace Bit.Core.Repositories.TableStorage; public class EventRepository : IEventRepository { - private readonly TableClient _tableClient; + private readonly CloudTable _table; public EventRepository(GlobalSettings globalSettings) : this(globalSettings.Events.ConnectionString) @@ -16,8 +16,9 @@ public class EventRepository : IEventRepository public EventRepository(string storageConnectionString) { - var tableClient = new TableServiceClient(storageConnectionString); - _tableClient = tableClient.GetTableClient("event"); + var storageAccount = CloudStorageAccount.Parse(storageConnectionString); + var tableClient = storageAccount.CreateCloudTableClient(); + _table = tableClient.GetTableReference("event"); } public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, @@ -75,7 +76,7 @@ public class EventRepository : IEventRepository throw new ArgumentException(nameof(e)); } - await CreateEventAsync(entity); + await CreateEntityAsync(entity); } public async Task CreateManyAsync(IEnumerable e) @@ -98,7 +99,7 @@ public class EventRepository : IEventRepository var groupEntities = group.ToList(); if (groupEntities.Count == 1) { - await CreateEventAsync(groupEntities.First()); + await CreateEntityAsync(groupEntities.First()); continue; } @@ -106,7 +107,7 @@ public class EventRepository : IEventRepository var iterations = groupEntities.Count / 100; for (var i = 0; i <= iterations; i++) { - var batch = new List(); + var batch = new TableBatchOperation(); var batchEntities = groupEntities.Skip(i * 100).Take(100); if (!batchEntities.Any()) { @@ -115,15 +116,19 @@ public class EventRepository : IEventRepository foreach (var entity in batchEntities) { - batch.Add(new TableTransactionAction(TableTransactionActionType.Add, - entity.ToAzureEvent())); + batch.InsertOrReplace(entity); } - await _tableClient.SubmitTransactionAsync(batch); + await _table.ExecuteBatchAsync(batch); } } } + public async Task CreateEntityAsync(ITableEntity entity) + { + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); + } + public async Task> GetManyAsync(string partitionKey, string rowKey, DateTime startDate, DateTime endDate, PageOptions pageOptions) { @@ -131,28 +136,60 @@ public class EventRepository : IEventRepository var end = CoreHelpers.DateTimeToTableStorageKey(endDate); var filter = MakeFilter(partitionKey, string.Format(rowKey, start), string.Format(rowKey, end)); + var query = new TableQuery().Where(filter).Take(pageOptions.PageSize); var result = new PagedResult(); - var query = _tableClient.QueryAsync(filter, pageOptions.PageSize); + var continuationToken = DeserializeContinuationToken(pageOptions?.ContinuationToken); - await using (var enumerator = query.AsPages(pageOptions?.ContinuationToken, - pageOptions.PageSize).GetAsyncEnumerator()) - { - await enumerator.MoveNextAsync(); - - result.ContinuationToken = enumerator.Current.ContinuationToken; - result.Data.AddRange(enumerator.Current.Values.Select(e => e.ToEventTableEntity())); - } + var queryResults = await _table.ExecuteQuerySegmentedAsync(query, continuationToken); + result.ContinuationToken = SerializeContinuationToken(queryResults.ContinuationToken); + result.Data.AddRange(queryResults.Results); return result; } - private async Task CreateEventAsync(EventTableEntity entity) - { - await _tableClient.UpsertEntityAsync(entity.ToAzureEvent()); - } - private string MakeFilter(string partitionKey, string rowStart, string rowEnd) { - return $"PartitionKey eq '{partitionKey}' and RowKey le '{rowStart}' and RowKey ge '{rowEnd}'"; + var rowFilter = TableQuery.CombineFilters( + TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.LessThanOrEqual, $"{rowStart}`"), + TableOperators.And, + TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.GreaterThanOrEqual, $"{rowEnd}_")); + + return TableQuery.CombineFilters( + TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, partitionKey), + TableOperators.And, + rowFilter); + } + + private string SerializeContinuationToken(TableContinuationToken token) + { + if (token == null) + { + return null; + } + + return string.Format("{0}__{1}__{2}__{3}", (int)token.TargetLocation, token.NextTableName, + token.NextPartitionKey, token.NextRowKey); + } + + private TableContinuationToken DeserializeContinuationToken(string token) + { + if (string.IsNullOrWhiteSpace(token)) + { + return null; + } + + var tokenParts = token.Split(new string[] { "__" }, StringSplitOptions.None); + if (tokenParts.Length < 4 || !Enum.TryParse(tokenParts[0], out StorageLocation tLoc)) + { + return null; + } + + return new TableContinuationToken + { + TargetLocation = tLoc, + NextTableName = string.IsNullOrWhiteSpace(tokenParts[1]) ? null : tokenParts[1], + NextPartitionKey = string.IsNullOrWhiteSpace(tokenParts[2]) ? null : tokenParts[2], + NextRowKey = string.IsNullOrWhiteSpace(tokenParts[3]) ? null : tokenParts[3] + }; } } diff --git a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs index 2dee07dc2b..32b466d1b3 100644 --- a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs +++ b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs @@ -1,12 +1,13 @@ -using Azure.Data.Tables; +using System.Net; using Bit.Core.Models.Data; using Bit.Core.Settings; +using Microsoft.Azure.Cosmos.Table; namespace Bit.Core.Repositories.TableStorage; public class InstallationDeviceRepository : IInstallationDeviceRepository { - private readonly TableClient _tableClient; + private readonly CloudTable _table; public InstallationDeviceRepository(GlobalSettings globalSettings) : this(globalSettings.Events.ConnectionString) @@ -14,13 +15,14 @@ public class InstallationDeviceRepository : IInstallationDeviceRepository public InstallationDeviceRepository(string storageConnectionString) { - var tableClient = new TableServiceClient(storageConnectionString); - _tableClient = tableClient.GetTableClient("installationdevice"); + var storageAccount = CloudStorageAccount.Parse(storageConnectionString); + var tableClient = storageAccount.CreateCloudTableClient(); + _table = tableClient.GetTableReference("installationdevice"); } public async Task UpsertAsync(InstallationDeviceEntity entity) { - await _tableClient.UpsertEntityAsync(entity); + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); } public async Task UpsertManyAsync(IList entities) @@ -50,7 +52,7 @@ public class InstallationDeviceRepository : IInstallationDeviceRepository var iterations = groupEntities.Count / 100; for (var i = 0; i <= iterations; i++) { - var batch = new List(); + var batch = new TableBatchOperation(); var batchEntities = groupEntities.Skip(i * 100).Take(100); if (!batchEntities.Any()) { @@ -59,16 +61,24 @@ public class InstallationDeviceRepository : IInstallationDeviceRepository foreach (var entity in batchEntities) { - batch.Add(new TableTransactionAction(TableTransactionActionType.UpsertReplace, entity)); + batch.InsertOrReplace(entity); } - await _tableClient.SubmitTransactionAsync(batch); + await _table.ExecuteBatchAsync(batch); } } } public async Task DeleteAsync(InstallationDeviceEntity entity) { - await _tableClient.DeleteEntityAsync(entity.PartitionKey, entity.RowKey); + try + { + entity.ETag = "*"; + await _table.ExecuteAsync(TableOperation.Delete(entity)); + } + catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) + { + throw; + } } } diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index e0d2e95dc9..f8f24cfbdb 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -1,5 +1,4 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; @@ -29,12 +28,6 @@ public interface IPaymentService int newlyPurchasedAdditionalStorage, DateTime? prorationDate = null); Task AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null); - Task AdjustSeats( - Provider provider, - Plan plan, - int currentlySubscribedSeats, - int newlySubscribedSeats, - DateTime? prorationDate = null); Task AdjustSmSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null); Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId, DateTime? prorationDate = null); diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs index 908dc2c0d8..073d5cdacd 100644 --- a/src/Core/Services/IStripeAdapter.cs +++ b/src/Core/Services/IStripeAdapter.cs @@ -1,5 +1,4 @@ using Bit.Core.Models.BitStripe; -using Stripe; namespace Bit.Core.Services; @@ -17,7 +16,6 @@ public interface IStripeAdapter Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); Task> InvoiceListAsync(StripeInvoiceListOptions options); - Task> InvoiceSearchAsync(InvoiceSearchOptions options); Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options); diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs index a7109252d4..ef8d13aea8 100644 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Services/Implementations/StripeAdapter.cs @@ -1,5 +1,4 @@ using Bit.Core.Models.BitStripe; -using Stripe; namespace Bit.Core.Services; @@ -104,9 +103,6 @@ public class StripeAdapter : IStripeAdapter return invoices; } - public async Task> InvoiceSearchAsync(InvoiceSearchOptions options) - => (await _invoiceService.SearchAsync(options)).Data; - public Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options) { return _invoiceService.UpdateAsync(id, options); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 234543a8f6..19437a1ee2 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1,5 +1,4 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; using Bit.Core.Entities; using Bit.Core.Enums; @@ -758,14 +757,14 @@ public class StripePaymentService : IPaymentService }).ToList(); } - private async Task FinalizeSubscriptionChangeAsync(ISubscriber subscriber, + private async Task FinalizeSubscriptionChangeAsync(IStorableSubscriber storableSubscriber, SubscriptionUpdate subscriptionUpdate, DateTime? prorationDate, bool invoiceNow = false) { // remember, when in doubt, throw var subGetOptions = new SubscriptionGetOptions(); // subGetOptions.AddExpand("customer"); subGetOptions.AddExpand("customer.tax"); - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions); + var sub = await _stripeAdapter.SubscriptionGetAsync(storableSubscriber.GatewaySubscriptionId, subGetOptions); if (sub == null) { throw new GatewayException("Subscription not found."); @@ -777,7 +776,6 @@ public class StripePaymentService : IPaymentService var chargeNow = collectionMethod == "charge_automatically"; var updatedItemOptions = subscriptionUpdate.UpgradeItemsOptions(sub); var isPm5864DollarThresholdEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5864DollarThreshold); - var isAnnualPlan = sub?.Items?.Data.FirstOrDefault()?.Plan?.Interval == "year"; var subUpdateOptions = new SubscriptionUpdateOptions { @@ -789,10 +787,25 @@ public class StripePaymentService : IPaymentService CollectionMethod = "send_invoice", ProrationDate = prorationDate, }; - if (!invoiceNow && isAnnualPlan && isPm5864DollarThresholdEnabled && sub.Status.Trim() != "trialing") + var immediatelyInvoice = false; + if (!invoiceNow && isPm5864DollarThresholdEnabled && sub.Status.Trim() != "trialing") { - subUpdateOptions.PendingInvoiceItemInterval = - new SubscriptionPendingInvoiceItemIntervalOptions { Interval = "month" }; + var upcomingInvoiceWithChanges = await _stripeAdapter.InvoiceUpcomingAsync(new UpcomingInvoiceOptions + { + Customer = storableSubscriber.GatewayCustomerId, + Subscription = storableSubscriber.GatewaySubscriptionId, + SubscriptionItems = ToInvoiceSubscriptionItemOptions(updatedItemOptions), + SubscriptionProrationBehavior = Constants.CreateProrations, + SubscriptionProrationDate = prorationDate, + SubscriptionBillingCycleAnchor = SubscriptionBillingCycleAnchor.Now + }); + + var isAnnualPlan = sub?.Items?.Data.FirstOrDefault()?.Plan?.Interval == "year"; + immediatelyInvoice = isAnnualPlan && upcomingInvoiceWithChanges.AmountRemaining >= 50000; + + subUpdateOptions.BillingCycleAnchor = immediatelyInvoice + ? SubscriptionBillingCycleAnchor.Now + : SubscriptionBillingCycleAnchor.Unchanged; } var pm5766AutomaticTaxIsEnabled = _featureService.IsEnabled(FeatureFlagKeys.PM5766AutomaticTax); @@ -845,16 +858,21 @@ public class StripePaymentService : IPaymentService { try { - if (chargeNow) + if (!isPm5864DollarThresholdEnabled || immediatelyInvoice || invoiceNow) { - paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync(subscriber, invoice); - } - else - { - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, - new InvoiceFinalizeOptions { AutoAdvance = false, }); - await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); - paymentIntentClientSecret = null; + if (chargeNow) + { + paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync(storableSubscriber, invoice); + } + else + { + invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, new InvoiceFinalizeOptions + { + AutoAdvance = false, + }); + await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); + paymentIntentClientSecret = null; + } } } catch @@ -925,17 +943,6 @@ public class StripePaymentService : IPaymentService return FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate); } - public Task AdjustSeats( - Provider provider, - StaticStore.Plan plan, - int currentlySubscribedSeats, - int newlySubscribedSeats, - DateTime? prorationDate = null) - => FinalizeSubscriptionChangeAsync( - provider, - new ProviderSubscriptionUpdate(plan.Type, currentlySubscribedSeats, newlySubscribedSeats), - prorationDate); - public Task AdjustSmSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats, DateTime? prorationDate = null) { return FinalizeSubscriptionChangeAsync(organization, new SmSeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate); @@ -1603,25 +1610,10 @@ public class StripePaymentService : IPaymentService return subscriptionInfo; } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, new SubscriptionGetOptions - { - Expand = ["test_clock"] - }); - + var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); if (sub != null) { subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub); - - if (_featureService.IsEnabled(FeatureFlagKeys.AC1795_UpdatedSubscriptionStatusSection)) - { - var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(sub); - - if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue) - { - subscriptionInfo.Subscription.SuspensionDate = suspensionDate; - subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate; - } - } } if (sub is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) @@ -1938,45 +1930,4 @@ public class StripePaymentService : IPaymentService ? subscriberName : subscriberName[..30]; } - - private async Task<(DateTime?, DateTime?)> GetSuspensionDateAsync(Subscription subscription) - { - if (subscription.Status is not "past_due" && subscription.Status is not "unpaid") - { - return (null, null); - } - - var openInvoices = await _stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions - { - Query = $"subscription:'{subscription.Id}' status:'open'" - }); - - if (openInvoices.Count == 0) - { - return (null, null); - } - - var currentDate = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow; - - switch (subscription.CollectionMethod) - { - case "charge_automatically": - { - var firstOverdueInvoice = openInvoices - .Where(invoice => invoice.PeriodEnd < currentDate && invoice.Attempted) - .MinBy(invoice => invoice.Created); - - return (firstOverdueInvoice?.Created.AddDays(14), firstOverdueInvoice?.PeriodEnd); - } - case "send_invoice": - { - var firstOverdueInvoice = openInvoices - .Where(invoice => invoice.DueDate < currentDate) - .MinBy(invoice => invoice.Created); - - return (firstOverdueInvoice?.DueDate?.AddDays(30), firstOverdueInvoice?.PeriodEnd); - } - default: return (null, null); - } - } } diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index 50b4efe6fb..84037a0a1c 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -221,8 +221,6 @@ public class GlobalSettings : IGlobalSettings private string _connectionString; private string _readOnlyConnectionString; private string _jobSchedulerConnectionString; - public bool SkipDatabasePreparation { get; set; } - public bool DisableDatabaseMaintenanceJobs { get; set; } public string ConnectionString { diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs index af658a409a..5d0becf7b4 100644 --- a/src/Core/Utilities/CoreHelpers.cs +++ b/src/Core/Utilities/CoreHelpers.cs @@ -32,10 +32,6 @@ public static class CoreHelpers private static readonly Random _random = new Random(); private static readonly string RealConnectingIp = "X-Connecting-IP"; private static readonly Regex _whiteSpaceRegex = new Regex(@"\s+"); - private static readonly JsonSerializerOptions _jsonSerializerOptions = new() - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }; /// /// Generate sequential Guid for Sql Server. @@ -782,12 +778,22 @@ public static class CoreHelpers return new T(); } - return System.Text.Json.JsonSerializer.Deserialize(jsonData, _jsonSerializerOptions); + var options = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }; + + return System.Text.Json.JsonSerializer.Deserialize(jsonData, options); } public static string ClassToJsonData(T data) { - return System.Text.Json.JsonSerializer.Serialize(data, _jsonSerializerOptions); + var options = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }; + + return System.Text.Json.JsonSerializer.Serialize(data, options); } public static ICollection AddIfNotExists(this ICollection list, T item) diff --git a/src/Core/Utilities/StaticStore.cs b/src/Core/Utilities/StaticStore.cs index 007f3374e0..dcf63df138 100644 --- a/src/Core/Utilities/StaticStore.cs +++ b/src/Core/Utilities/StaticStore.cs @@ -147,6 +147,7 @@ public static class StaticStore public static Plan GetPlan(PlanType planType) => Plans.SingleOrDefault(p => p.Type == planType); + public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) => SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType); diff --git a/src/Events/appsettings.json b/src/Events/appsettings.json index e72b978f2f..101911bb0d 100644 --- a/src/Events/appsettings.json +++ b/src/Events/appsettings.json @@ -14,6 +14,10 @@ "events": { "connectionString": "SECRET" }, + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + }, "sentry": { "dsn": "SECRET" }, diff --git a/src/EventsProcessor/AzureQueueHostedService.cs b/src/EventsProcessor/AzureQueueHostedService.cs index b1b309b50f..03c0034539 100644 --- a/src/EventsProcessor/AzureQueueHostedService.cs +++ b/src/EventsProcessor/AzureQueueHostedService.cs @@ -30,7 +30,6 @@ public class AzureQueueHostedService : IHostedService, IDisposable _logger.LogInformation(Constants.BypassFiltersEventId, "Starting service."); _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; } @@ -40,10 +39,8 @@ public class AzureQueueHostedService : IHostedService, IDisposable { return; } - _logger.LogWarning("Stopping service."); - - await _cts.CancelAsync(); + _cts.Cancel(); await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); cancellationToken.ThrowIfCancellationRequested(); } @@ -67,15 +64,13 @@ public class AzureQueueHostedService : IHostedService, IDisposable { try { - var messages = await _queueClient.ReceiveMessagesAsync(32, - cancellationToken: cancellationToken); + var messages = await _queueClient.ReceiveMessagesAsync(32); if (messages.Value?.Any() ?? false) { foreach (var message in messages.Value) { await ProcessQueueMessageAsync(message.DecodeMessageText(), cancellationToken); - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, - cancellationToken); + await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); } } else @@ -83,15 +78,14 @@ public class AzureQueueHostedService : IHostedService, IDisposable await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); } } - catch (Exception ex) + catch (Exception e) { - _logger.LogError(ex, "Error occurred processing message block."); - + _logger.LogError(e, "Exception occurred: " + e.Message); await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); } } - _logger.LogWarning("Done processing messages."); + _logger.LogWarning("Done processing."); } public async Task ProcessQueueMessageAsync(string message, CancellationToken cancellationToken) @@ -104,14 +98,14 @@ public class AzureQueueHostedService : IHostedService, IDisposable try { _logger.LogInformation("Processing message."); - var events = new List(); + using var jsonDocument = JsonDocument.Parse(message); var root = jsonDocument.RootElement; if (root.ValueKind == JsonValueKind.Array) { var indexedEntities = root.Deserialize>() - .SelectMany(EventTableEntity.IndexEvent); + .SelectMany(e => EventTableEntity.IndexEvent(e)); events.AddRange(indexedEntities); } else if (root.ValueKind == JsonValueKind.Object) @@ -120,15 +114,12 @@ public class AzureQueueHostedService : IHostedService, IDisposable events.AddRange(EventTableEntity.IndexEvent(eventMessage)); } - cancellationToken.ThrowIfCancellationRequested(); - await _eventWriteService.CreateManyAsync(events); - _logger.LogInformation("Processed message."); } - catch (JsonException ex) + catch (JsonException) { - _logger.LogError(ex, "Unable to parse message."); + _logger.LogError("JsonReaderException: Unable to parse message."); } } } diff --git a/src/EventsProcessor/appsettings.json b/src/EventsProcessor/appsettings.json index c2c77bcb0d..af0ca259fa 100644 --- a/src/EventsProcessor/appsettings.json +++ b/src/EventsProcessor/appsettings.json @@ -2,6 +2,10 @@ "azureStorageConnectionString": "SECRET", "globalSettings": { "selfHosted": false, - "projectName": "Events Processor" + "projectName": "Events Processor", + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + } } } diff --git a/src/Icons/appsettings.json b/src/Icons/appsettings.json index 6b4e2992e0..65267ef4e9 100644 --- a/src/Icons/appsettings.json +++ b/src/Icons/appsettings.json @@ -1,6 +1,10 @@ { "globalSettings": { - "projectName": "Icons" + "projectName": "Icons", + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + } }, "iconsSettings": { "cacheEnabled": true, diff --git a/src/Identity/appsettings.json b/src/Identity/appsettings.json index 16c3efe46b..e3626b4e16 100644 --- a/src/Identity/appsettings.json +++ b/src/Identity/appsettings.json @@ -27,6 +27,10 @@ "events": { "connectionString": "SECRET" }, + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + }, "sentry": { "dsn": "SECRET" }, diff --git a/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs b/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs index f8448f4198..761545a255 100644 --- a/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs +++ b/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs @@ -14,7 +14,7 @@ public class ProviderPlanRepository( globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString), IProviderPlanRepository { - public async Task> GetByProviderId(Guid providerId) + public async Task GetByProviderId(Guid providerId) { var sqlConnection = new SqlConnection(ConnectionString); @@ -23,6 +23,6 @@ public class ProviderPlanRepository( new { ProviderId = providerId }, commandType: CommandType.StoredProcedure); - return results.ToArray(); + return results.FirstOrDefault(); } } diff --git a/src/Infrastructure.Dapper/Infrastructure.Dapper.csproj b/src/Infrastructure.Dapper/Infrastructure.Dapper.csproj index 046009ef73..6c7ad57d19 100644 --- a/src/Infrastructure.Dapper/Infrastructure.Dapper.csproj +++ b/src/Infrastructure.Dapper/Infrastructure.Dapper.csproj @@ -5,7 +5,7 @@ - + diff --git a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs index 386f7115d7..2f9a707b27 100644 --- a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs +++ b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs @@ -16,17 +16,14 @@ public class ProviderPlanRepository( mapper, context => context.ProviderPlans), IProviderPlanRepository { - public async Task> GetByProviderId(Guid providerId) + public async Task GetByProviderId(Guid providerId) { using var serviceScope = ServiceScopeFactory.CreateScope(); - var databaseContext = GetDatabaseContext(serviceScope); - var query = from providerPlan in databaseContext.ProviderPlans where providerPlan.ProviderId == providerId select providerPlan; - - return await query.ToArrayAsync(); + return await query.FirstOrDefaultAsync(); } } diff --git a/src/Notifications/appsettings.json b/src/Notifications/appsettings.json index 020d98cbd6..82355a0771 100644 --- a/src/Notifications/appsettings.json +++ b/src/Notifications/appsettings.json @@ -18,6 +18,10 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, + "documentDb": { + "uri": "SECRET", + "key": "SECRET" + }, "sentry": { "dsn": "SECRET" }, diff --git a/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs b/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs index f669e89eb0..90a2335c22 100644 --- a/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs +++ b/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs @@ -64,13 +64,4 @@ public class ApiApplicationFactory : WebApplicationFactoryBase base.Dispose(disposing); SqliteConnection.Dispose(); } - - /// - /// Helper for logging in via client secret. - /// Currently used for Secrets Manager service accounts - /// - public async Task LoginWithClientSecretAsync(Guid clientId, string clientSecret) - { - return await _identityApplicationFactory.TokenFromAccessTokenAsync(clientId, clientSecret); - } } diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/AccessPoliciesControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/AccessPoliciesControllerTests.cs index e1cce68704..b8eb4a7700 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/AccessPoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/AccessPoliciesControllerTests.cs @@ -1,7 +1,7 @@ using System.Net; +using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.SecretsManager.Enums; -using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.Models.Response; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.SecretsManager.Models.Response; @@ -28,7 +28,6 @@ public class AccessPoliciesControllerTests : IClassFixture(); _projectRepository = _factory.GetService(); _groupRepository = _factory.GetService(); - _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -56,6 +54,12 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(serviceAccountId, result.ServiceAccountAccessPolicies.First().ServiceAccountId); + Assert.Equal(serviceAccountId, result!.ServiceAccountAccessPolicies.First().ServiceAccountId); Assert.True(result.ServiceAccountAccessPolicies.First().Read); Assert.True(result.ServiceAccountAccessPolicies.First().Write); AssertHelper.AssertRecent(result.ServiceAccountAccessPolicies.First().RevisionDate); @@ -164,7 +168,7 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -245,13 +249,13 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(expectedRead, result.Read); + Assert.Equal(expectedRead, result!.Read); Assert.Equal(expectedWrite, result.Write); AssertHelper.AssertRecent(result.RevisionDate); var updatedAccessPolicy = await _accessPolicyRepository.GetByIdAsync(result.Id); Assert.NotNull(updatedAccessPolicy); - Assert.Equal(expectedRead, updatedAccessPolicy.Read); + Assert.Equal(expectedRead, updatedAccessPolicy!.Read); Assert.Equal(expectedWrite, updatedAccessPolicy.Write); AssertHelper.AssertRecent(updatedAccessPolicy.RevisionDate); } @@ -267,7 +271,7 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -323,7 +327,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Empty(result.UserAccessPolicies); + Assert.Empty(result!.UserAccessPolicies); Assert.Empty(result.GroupAccessPolicies); Assert.Empty(result.ServiceAccountAccessPolicies); } @@ -353,7 +357,7 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -405,7 +409,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result?.ServiceAccountAccessPolicies); - Assert.Single(result.ServiceAccountAccessPolicies); + Assert.Single(result!.ServiceAccountAccessPolicies); } [Theory] @@ -419,7 +423,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.NotEmpty(result.Data); + Assert.NotEmpty(result!.Data); } [Theory] @@ -463,7 +467,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.Empty(result.Data); + Assert.Empty(result!.Data); } [Theory] @@ -503,7 +507,7 @@ public class AccessPoliciesControllerTests : IClassFixture @@ -537,7 +541,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.NotEmpty(result.Data); + Assert.NotEmpty(result!.Data); Assert.Equal(serviceAccount.Id, result.Data.First(x => x.Id == serviceAccount.Id).Id); } @@ -552,7 +556,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.Empty(result.Data); + Assert.Empty(result!.Data); } [Theory] @@ -588,7 +592,7 @@ public class AccessPoliciesControllerTests : IClassFixture @@ -619,7 +623,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.NotEmpty(result.Data); + Assert.NotEmpty(result!.Data); Assert.Equal(project.Id, result.Data.First(x => x.Id == project.Id).Id); } @@ -634,7 +638,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.NotEmpty(result.Data); + Assert.NotEmpty(result!.Data); Assert.Equal(projectId, result.Data.First().GrantedProjectId); var createdAccessPolicy = await _accessPolicyRepository.GetByIdAsync(result.Data.First().Id); Assert.NotNull(createdAccessPolicy); - Assert.Equal(result.Data.First().Read, createdAccessPolicy.Read); + Assert.Equal(result.Data.First().Read, createdAccessPolicy!.Read); Assert.Equal(result.Data.First().Write, createdAccessPolicy.Write); Assert.Equal(result.Data.First().Id, createdAccessPolicy.Id); AssertHelper.AssertRecent(createdAccessPolicy.CreationDate); @@ -743,7 +747,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.Empty(result.Data); + Assert.Empty(result!.Data); } [Fact] @@ -778,7 +782,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.Empty(result.Data); + Assert.Empty(result!.Data); } [Theory] @@ -797,13 +801,13 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -821,7 +825,7 @@ public class AccessPoliciesControllerTests : IClassFixture>(); Assert.NotNull(result?.Data); - Assert.NotEmpty(result.Data); + Assert.NotEmpty(result!.Data); Assert.Equal(initData.ServiceAccountId, result.Data.First().ServiceAccountId); Assert.NotNull(result.Data.First().ServiceAccountName); Assert.NotNull(result.Data.First().GrantedProjectName); @@ -838,7 +842,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Empty(result.UserAccessPolicies); + Assert.Empty(result!.UserAccessPolicies); Assert.Empty(result.GroupAccessPolicies); } @@ -877,7 +881,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result?.UserAccessPolicies); - Assert.Single(result.UserAccessPolicies); + Assert.Single(result!.UserAccessPolicies); } [Theory] @@ -920,7 +924,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Empty(result.UserAccessPolicies); + Assert.Empty(result!.UserAccessPolicies); Assert.Empty(result.GroupAccessPolicies); } @@ -1057,7 +1061,7 @@ public class AccessPoliciesControllerTests : IClassFixture(); Assert.NotNull(result?.UserAccessPolicies); - Assert.Single(result.UserAccessPolicies); + Assert.Single(result!.UserAccessPolicies); } [Theory] @@ -1096,7 +1100,7 @@ public class AccessPoliciesControllerTests : IClassFixture { new UserProjectAccessPolicy @@ -1357,6 +1361,35 @@ public class AccessPoliciesControllerTests : IClassFixture SetupUserServiceAccountAccessPolicyRequestAsync( + PermissionType permissionType, Guid userId, Guid serviceAccountId) + { + if (permissionType == PermissionType.RunAsUserWithPermission) + { + var (email, newOrgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); + await LoginAsync(email); + var accessPolicies = new List + { + new UserServiceAccountAccessPolicy + { + GrantedServiceAccountId = serviceAccountId, + OrganizationUserId = newOrgUser.Id, + Read = true, + Write = true, + }, + }; + await _accessPolicyRepository.CreateManyAsync(accessPolicies); + } + + return new AccessPoliciesCreateRequest + { + UserAccessPolicyRequests = new List + { + new() { GranteeId = userId, Read = true, Write = true }, + }, + }; + } + private class RequestSetupData { public Guid ProjectId { get; set; } diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/ProjectsControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/ProjectsControllerTests.cs index 95ddfd678e..523998ee28 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/ProjectsControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/ProjectsControllerTests.cs @@ -1,7 +1,7 @@ using System.Net; +using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.SecretsManager.Enums; -using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.Models.Response; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.SecretsManager.Models.Response; @@ -10,6 +10,7 @@ using Bit.Core.Enums; using Bit.Core.SecretsManager.Entities; using Bit.Core.SecretsManager.Repositories; using Bit.Test.Common.Helpers; +using Pipelines.Sockets.Unofficial.Arenas; using Xunit; namespace Bit.Api.IntegrationTest.SecretsManager.Controllers; @@ -23,7 +24,6 @@ public class ProjectsControllerTests : IClassFixture, IAs private readonly ApiApplicationFactory _factory; private readonly IProjectRepository _projectRepository; private readonly IAccessPolicyRepository _accessPolicyRepository; - private readonly LoginHelper _loginHelper; private string _email = null!; private SecretsManagerOrganizationHelper _organizationHelper = null!; @@ -34,7 +34,6 @@ public class ProjectsControllerTests : IClassFixture, IAs _client = _factory.CreateClient(); _projectRepository = _factory.GetService(); _accessPolicyRepository = _factory.GetService(); - _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -50,6 +49,12 @@ public class ProjectsControllerTests : IClassFixture, IAs return Task.CompletedTask; } + private async Task LoginAsync(string email) + { + var tokens = await _factory.LoginAsync(email); + _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); + } + [Theory] [InlineData(false, false, false)] [InlineData(false, false, true)] @@ -61,7 +66,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task ListByOrganization_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var response = await _client.GetAsync($"/organizations/{org.Id}/projects"); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -72,7 +77,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); await CreateProjectsAsync(org.Id); @@ -81,7 +86,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var result = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(result); - Assert.Empty(result.Data); + Assert.Empty(result!.Data); } [Theory] @@ -96,7 +101,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var result = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(result); - Assert.NotEmpty(result.Data); + Assert.NotEmpty(result!.Data); Assert.Equal(projectIds.Count, result.Data.Count()); } @@ -111,7 +116,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Create_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var request = new ProjectCreateRequestModel { Name = _mockEncryptedString }; @@ -124,7 +129,7 @@ public class ProjectsControllerTests : IClassFixture, IAs [InlineData(PermissionType.RunAsUserWithPermission)] public async Task Create_AtMaxProjects_BadRequest(PermissionType permissionType) { - var (_, organization) = await SetupProjectsWithAccessAsync(permissionType); + var (_, organization) = await SetupProjectsWithAccessAsync(permissionType, 3); var request = new ProjectCreateRequestModel { Name = _mockEncryptedString }; var response = await _client.PostAsJsonAsync($"/organizations/{organization.Id}/projects", request); @@ -138,14 +143,14 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Create_Success(PermissionType permissionType) { var (org, adminOrgUser) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var orgUserId = adminOrgUser.Id; var currentUserId = adminOrgUser.UserId!.Value; if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); orgUserId = orgUser.Id; currentUserId = orgUser.UserId!.Value; } @@ -157,7 +162,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var result = await response.Content.ReadFromJsonAsync(); Assert.NotNull(result); - Assert.Equal(request.Name, result.Name); + Assert.Equal(request.Name, result!.Name); AssertHelper.AssertRecent(result.RevisionDate); AssertHelper.AssertRecent(result.CreationDate); @@ -191,7 +196,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Update_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var initialProject = await _projectRepository.CreateAsync(new Project { @@ -239,7 +244,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Update_NonExistingProject_NotFound() { await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var request = new ProjectUpdateRequestModel { @@ -257,7 +262,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var project = await _projectRepository.CreateAsync(new Project { @@ -287,7 +292,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Get_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project { @@ -308,7 +313,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var createdProject = await _projectRepository.CreateAsync(new Project { @@ -325,7 +330,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var createdProject = await _projectRepository.CreateAsync(new Project { @@ -333,7 +338,7 @@ public class ProjectsControllerTests : IClassFixture, IAs Name = _mockEncryptedString, }); - await _client.PostAsync("/projects/delete", JsonContent.Create(createdProject.Id)); + var deleteResponse = await _client.PostAsync("/projects/delete", JsonContent.Create(createdProject.Id)); var response = await _client.GetAsync($"/projects/{createdProject.Id}"); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -367,7 +372,7 @@ public class ProjectsControllerTests : IClassFixture, IAs public async Task Delete_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var projectIds = await CreateProjectsAsync(org.Id); @@ -380,7 +385,7 @@ public class ProjectsControllerTests : IClassFixture, IAs { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var projectIds = await CreateProjectsAsync(org.Id); @@ -389,7 +394,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var results = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(results); Assert.Equal(projectIds.OrderBy(x => x), - results.Data.Select(x => x.Id).OrderBy(x => x)); + results!.Data.Select(x => x.Id).OrderBy(x => x)); Assert.All(results.Data, item => Assert.Equal("access denied", item.Error)); } @@ -406,7 +411,7 @@ public class ProjectsControllerTests : IClassFixture, IAs var results = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(results); Assert.Equal(projectIds.OrderBy(x => x), - results.Data.Select(x => x.Id).OrderBy(x => x)); + results!.Data.Select(x => x.Id).OrderBy(x => x)); Assert.DoesNotContain(results.Data, x => x.Error != null); var projects = await _projectRepository.GetManyWithSecretsByIds(projectIds); @@ -433,7 +438,7 @@ public class ProjectsControllerTests : IClassFixture, IAs int projectsToCreate = 3) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var projectIds = await CreateProjectsAsync(org.Id, projectsToCreate); if (permissionType == PermissionType.RunAsAdmin) @@ -442,7 +447,7 @@ public class ProjectsControllerTests : IClassFixture, IAs } var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var accessPolicies = projectIds.Select(projectId => new UserProjectAccessPolicy { @@ -462,7 +467,7 @@ public class ProjectsControllerTests : IClassFixture, IAs private async Task SetupProjectWithAccessAsync(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var initialProject = await _projectRepository.CreateAsync(new Project { @@ -476,7 +481,7 @@ public class ProjectsControllerTests : IClassFixture, IAs } var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var accessPolicies = new List { diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsControllerTests.cs index 0ff7396eda..4932ad9b9b 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsControllerTests.cs @@ -1,7 +1,7 @@ using System.Net; +using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.SecretsManager.Enums; -using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.Models.Response; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.SecretsManager.Models.Response; @@ -23,7 +23,6 @@ public class SecretsControllerTests : IClassFixture, IAsy private readonly ISecretRepository _secretRepository; private readonly IProjectRepository _projectRepository; private readonly IAccessPolicyRepository _accessPolicyRepository; - private readonly LoginHelper _loginHelper; private string _email = null!; private SecretsManagerOrganizationHelper _organizationHelper = null!; @@ -35,7 +34,6 @@ public class SecretsControllerTests : IClassFixture, IAsy _secretRepository = _factory.GetService(); _projectRepository = _factory.GetService(); _accessPolicyRepository = _factory.GetService(); - _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -51,6 +49,12 @@ public class SecretsControllerTests : IClassFixture, IAsy return Task.CompletedTask; } + private async Task LoginAsync(string email) + { + var tokens = await _factory.LoginAsync(email); + _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); + } + [Theory] [InlineData(false, false, false)] [InlineData(false, false, true)] @@ -62,7 +66,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task ListByOrganization_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var response = await _client.GetAsync($"/organizations/{org.Id}/secrets"); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -73,8 +77,8 @@ public class SecretsControllerTests : IClassFixture, IAsy [InlineData(PermissionType.RunAsUserWithPermission)] public async Task ListByOrganization_Success(PermissionType permissionType) { - var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + var (org, orgUserOwner) = await _organizationHelper.Initialize(true, true, true); + await LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project { @@ -86,7 +90,7 @@ public class SecretsControllerTests : IClassFixture, IAsy if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var accessPolicies = new List { @@ -118,7 +122,7 @@ public class SecretsControllerTests : IClassFixture, IAsy var result = await response.Content.ReadFromJsonAsync(); Assert.NotNull(result); - Assert.NotEmpty(result.Secrets); + Assert.NotEmpty(result!.Secrets); Assert.Equal(secretIds.Count, result.Secrets.Count()); } @@ -133,7 +137,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Create_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var request = new SecretCreateRequestModel { @@ -150,7 +154,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithoutProject_RunAsAdmin_Success() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var request = new SecretCreateRequestModel { @@ -164,7 +168,7 @@ public class SecretsControllerTests : IClassFixture, IAsy var result = await response.Content.ReadFromJsonAsync(); Assert.NotNull(result); - Assert.Equal(request.Key, result.Key); + Assert.Equal(request.Key, result!.Key); Assert.Equal(request.Value, result.Value); Assert.Equal(request.Note, result.Note); AssertHelper.AssertRecent(result.RevisionDate); @@ -184,7 +188,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithDifferentProjectOrgId_RunAsAdmin_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var anotherOrg = await _organizationHelper.CreateSmOrganizationAsync(); var project = @@ -206,7 +210,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithMultipleProjects_RunAsAdmin_BadRequest() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var projectA = await _projectRepository.CreateAsync(new Project { OrganizationId = org.Id, Name = "123A" }); var projectB = await _projectRepository.CreateAsync(new Project { OrganizationId = org.Id, Name = "123B" }); @@ -227,8 +231,8 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithoutProject_RunAsUser_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); + await LoginAsync(email); var request = new SecretCreateRequestModel { @@ -247,9 +251,9 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task CreateWithProject_Success(PermissionType permissionType) { var (org, orgAdminUser) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); - var accessType = AccessClientType.NoAccessCheck; + AccessClientType accessType = AccessClientType.NoAccessCheck; var project = await _projectRepository.CreateAsync(new Project() { @@ -263,12 +267,12 @@ public class SecretsControllerTests : IClassFixture, IAsy if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); accessType = AccessClientType.User; var accessPolicies = new List { - new UserProjectAccessPolicy + new Core.SecretsManager.Entities.UserProjectAccessPolicy { GrantedProjectId = project.Id, OrganizationUserId = orgUser.Id , Read = true, Write = true, }, @@ -292,7 +296,7 @@ public class SecretsControllerTests : IClassFixture, IAsy var secret = result.Secret; Assert.NotNull(secretResult); - Assert.Equal(secret.Id, secretResult.Id); + Assert.Equal(secret.Id, secretResult!.Id); Assert.Equal(secret.OrganizationId, secretResult.OrganizationId); Assert.Equal(secret.Key, secretResult.Key); Assert.Equal(secret.Value, secretResult.Value); @@ -312,7 +316,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Get_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -332,7 +336,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Get_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project() { @@ -344,7 +348,7 @@ public class SecretsControllerTests : IClassFixture, IAsy if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var accessPolicies = new List { @@ -357,8 +361,8 @@ public class SecretsControllerTests : IClassFixture, IAsy } else { - var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.Admin, true); - await _loginHelper.LoginAsync(email); + var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.Admin, true); + await LoginAsync(email); } var secret = await _secretRepository.CreateAsync(new Secret @@ -391,7 +395,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByProject_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project { @@ -407,8 +411,8 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByProject_UserWithNoPermission_EmptyList() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); + await LoginAsync(email); var project = await _projectRepository.CreateAsync(new Project() { @@ -417,7 +421,7 @@ public class SecretsControllerTests : IClassFixture, IAsy Name = _mockEncryptedString }); - await _secretRepository.CreateAsync(new Secret + var secret = await _secretRepository.CreateAsync(new Secret { OrganizationId = org.Id, Key = _mockEncryptedString, @@ -430,8 +434,8 @@ public class SecretsControllerTests : IClassFixture, IAsy response.EnsureSuccessStatusCode(); var result = await response.Content.ReadFromJsonAsync(); Assert.NotNull(result); - Assert.Empty(result.Secrets); - Assert.Empty(result.Projects); + Assert.Empty(result!.Secrets); + Assert.Empty(result!.Projects); } [Theory] @@ -440,7 +444,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByProject_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var project = await _projectRepository.CreateAsync(new Project() { @@ -452,7 +456,7 @@ public class SecretsControllerTests : IClassFixture, IAsy if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var accessPolicies = new List { @@ -497,7 +501,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Update_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -521,18 +525,32 @@ public class SecretsControllerTests : IClassFixture, IAsy [Theory] [InlineData(PermissionType.RunAsAdmin)] [InlineData(PermissionType.RunAsUserWithPermission)] - [InlineData(PermissionType.RunAsServiceAccountWithPermission)] public async Task Update_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); + await LoginAsync(_email); + var project = await _projectRepository.CreateAsync(new Project() { - Id = Guid.NewGuid(), + Id = new Guid(), OrganizationId = org.Id, Name = _mockEncryptedString }); - await SetupProjectPermissionAndLoginAsync(permissionType, project); + if (permissionType == PermissionType.RunAsUserWithPermission) + { + var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); + await LoginAsync(email); + + var accessPolicies = new List + { + new UserProjectAccessPolicy + { + GrantedProjectId = project.Id, OrganizationUserId = orgUser.Id, Read = true, Write = true, + }, + }; + await _accessPolicyRepository.CreateManyAsync(accessPolicies); + } var secret = await _secretRepository.CreateAsync(new Secret { @@ -540,7 +558,7 @@ public class SecretsControllerTests : IClassFixture, IAsy Key = _mockEncryptedString, Value = _mockEncryptedString, Note = _mockEncryptedString, - Projects = permissionType != PermissionType.RunAsAdmin ? new List() { project } : null + Projects = permissionType == PermissionType.RunAsUserWithPermission ? new List() { project } : null }); var request = new SecretUpdateRequestModel() @@ -548,7 +566,7 @@ public class SecretsControllerTests : IClassFixture, IAsy Key = _mockEncryptedString, Value = "2.3Uk+WNBIoU5xzmVFNcoWzz==|1MsPIYuRfdOHfu/0uY6H2Q==|/98xy4wb6pHP1VTZ9JcNCYgQjEUMFPlqJgCwRk1YXKg=", Note = _mockEncryptedString, - ProjectIds = permissionType != PermissionType.RunAsAdmin ? new Guid[] { project.Id } : null + ProjectIds = permissionType == PermissionType.RunAsUserWithPermission ? new Guid[] { project.Id } : null }; var response = await _client.PutAsJsonAsync($"/secrets/{secret.Id}", request); @@ -577,7 +595,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task UpdateWithDifferentProjectOrgId_RunAsAdmin_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var anotherOrg = await _organizationHelper.CreateSmOrganizationAsync(); var project = await _projectRepository.CreateAsync(new Project { Name = "123", OrganizationId = anotherOrg.Id }); @@ -606,7 +624,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task UpdateWithMultipleProjects_BadRequest() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var projectA = await _projectRepository.CreateAsync(new Project { OrganizationId = org.Id, Name = "123A" }); var projectB = await _projectRepository.CreateAsync(new Project { OrganizationId = org.Id, Name = "123B" }); @@ -642,7 +660,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task Delete_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -662,34 +680,33 @@ public class SecretsControllerTests : IClassFixture, IAsy { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); - var (_, secretIds) = await CreateSecretsAsync(org.Id); + var (_, secretIds) = await CreateSecretsAsync(org.Id, 3); var response = await _client.PostAsync("/secrets/delete", JsonContent.Create(secretIds)); var results = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(results); Assert.Equal(secretIds.OrderBy(x => x), - results.Data.Select(x => x.Id).OrderBy(x => x)); + results!.Data.Select(x => x.Id).OrderBy(x => x)); Assert.All(results.Data, item => Assert.Equal("access denied", item.Error)); } [Theory] [InlineData(PermissionType.RunAsAdmin)] [InlineData(PermissionType.RunAsUserWithPermission)] - [InlineData(PermissionType.RunAsServiceAccountWithPermission)] public async Task Delete_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var (project, secretIds) = await CreateSecretsAsync(org.Id); if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var accessPolicies = new List { @@ -706,8 +723,8 @@ public class SecretsControllerTests : IClassFixture, IAsy var results = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(results?.Data); - Assert.Equal(secretIds.Count, results.Data.Count()); - foreach (var result in results.Data) + Assert.Equal(secretIds.Count, results!.Data.Count()); + foreach (var result in results!.Data) { Assert.Contains(result.Id, secretIds); Assert.Null(result.Error); @@ -728,7 +745,7 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByIds_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -750,14 +767,14 @@ public class SecretsControllerTests : IClassFixture, IAsy public async Task GetSecretsByIds_Success(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var (project, secretIds) = await CreateSecretsAsync(org.Id); if (permissionType == PermissionType.RunAsUserWithPermission) { var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var accessPolicies = new List { @@ -771,7 +788,7 @@ public class SecretsControllerTests : IClassFixture, IAsy else { var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.Admin, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); } var request = new GetSecretsRequestModel { Ids = secretIds }; @@ -780,8 +797,8 @@ public class SecretsControllerTests : IClassFixture, IAsy response.EnsureSuccessStatusCode(); var result = await response.Content.ReadFromJsonAsync>(); Assert.NotNull(result); - Assert.NotEmpty(result.Data); - Assert.Equal(secretIds.Count, result.Data.Count()); + Assert.NotEmpty(result!.Data); + Assert.Equal(secretIds.Count, result!.Data.Count()); } private async Task<(Project Project, List secretIds)> CreateSecretsAsync(Guid orgId, int numberToCreate = 3) @@ -809,48 +826,4 @@ public class SecretsControllerTests : IClassFixture, IAsy return (project, secretIds); } - - private async Task SetupProjectPermissionAndLoginAsync(PermissionType permissionType, Project project) - { - switch (permissionType) - { - case PermissionType.RunAsAdmin: - { - await _loginHelper.LoginAsync(_email); - break; - } - case PermissionType.RunAsUserWithPermission: - { - var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); - - var accessPolicies = new List - { - new UserProjectAccessPolicy - { - GrantedProjectId = project.Id, OrganizationUserId = orgUser.Id, Read = true, Write = true, - }, - }; - await _accessPolicyRepository.CreateManyAsync(accessPolicies); - break; - } - case PermissionType.RunAsServiceAccountWithPermission: - { - var apiKeyDetails = await _organizationHelper.CreateNewServiceAccountApiKeyAsync(); - await _loginHelper.LoginWithApiKeyAsync(apiKeyDetails); - - var accessPolicies = new List - { - new ServiceAccountProjectAccessPolicy - { - GrantedProjectId = project.Id, ServiceAccountId = apiKeyDetails.ApiKey.ServiceAccountId, Read = true, Write = true, - }, - }; - await _accessPolicyRepository.CreateManyAsync(accessPolicies); - break; - } - default: - throw new ArgumentOutOfRangeException(nameof(permissionType), permissionType, null); - } - } } diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerEventsControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerEventsControllerTests.cs index 036e307d39..4c053c3a2e 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerEventsControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerEventsControllerTests.cs @@ -1,7 +1,6 @@ using System.Net; using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; -using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Core.SecretsManager.Entities; using Bit.Core.SecretsManager.Repositories; using Xunit; diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerPortingControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerPortingControllerTests.cs index ba41c1e862..c57ceb20d9 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerPortingControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/SecretsManagerPortingControllerTests.cs @@ -1,7 +1,8 @@ using System.Net; +using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; -using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.SecretsManager.Models.Request; +using Bit.Core.SecretsManager.Repositories; using Xunit; namespace Bit.Api.IntegrationTest.SecretsManager.Controllers; @@ -10,7 +11,8 @@ public class SecretsManagerPortingControllerTests : IClassFixture(); + _accessPolicyRepository = _factory.GetService(); } public async Task InitializeAsync() @@ -35,6 +38,12 @@ public class SecretsManagerPortingControllerTests : IClassFixture(); var secretsList = new List(); @@ -67,7 +76,7 @@ public class SecretsManagerPortingControllerTests : IClassFixture, private readonly HttpClient _client; private readonly ApiApplicationFactory _factory; private readonly ISecretRepository _secretRepository; - private readonly LoginHelper _loginHelper; private string _email = null!; private SecretsManagerOrganizationHelper _organizationHelper = null!; @@ -27,7 +26,6 @@ public class SecretsTrashControllerTests : IClassFixture, _factory = factory; _client = _factory.CreateClient(); _secretRepository = _factory.GetService(); - _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -43,6 +41,12 @@ public class SecretsTrashControllerTests : IClassFixture, return Task.CompletedTask; } + private async Task LoginAsync(string email) + { + var tokens = await _factory.LoginAsync(email); + _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); + } + [Theory] [InlineData(false, false, false)] [InlineData(false, false, true)] @@ -54,7 +58,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task ListByOrganization_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var response = await _client.GetAsync($"/secrets/{org.Id}/trash"); Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); @@ -65,7 +69,7 @@ public class SecretsTrashControllerTests : IClassFixture, { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var response = await _client.GetAsync($"/secrets/{org.Id}/trash"); Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); @@ -75,7 +79,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task ListByOrganization_Success() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); await _secretRepository.CreateAsync(new Secret { @@ -110,7 +114,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Empty_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var ids = new List { Guid.NewGuid() }; var response = await _client.PostAsJsonAsync($"/secrets/{org.Id}/trash/empty", ids); @@ -122,7 +126,7 @@ public class SecretsTrashControllerTests : IClassFixture, { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var ids = new List { Guid.NewGuid() }; var response = await _client.PostAsJsonAsync($"/secrets/{org.Id}/trash/empty", ids); @@ -133,7 +137,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Empty_Invalid_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -151,7 +155,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Empty_Success() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -177,7 +181,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Restore_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) { var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var ids = new List { Guid.NewGuid() }; var response = await _client.PostAsJsonAsync($"/secrets/{org.Id}/trash/restore", ids); @@ -189,7 +193,7 @@ public class SecretsTrashControllerTests : IClassFixture, { var (org, _) = await _organizationHelper.Initialize(true, true, true); var (email, _) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); - await _loginHelper.LoginAsync(email); + await LoginAsync(email); var ids = new List { Guid.NewGuid() }; var response = await _client.PostAsJsonAsync($"/secrets/{org.Id}/trash/restore", ids); @@ -200,7 +204,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Restore_Invalid_NotFound() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { @@ -218,7 +222,7 @@ public class SecretsTrashControllerTests : IClassFixture, public async Task Restore_Success() { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var secret = await _secretRepository.CreateAsync(new Secret { diff --git a/test/Api.IntegrationTest/SecretsManager/Controllers/ServiceAccountsControllerTests.cs b/test/Api.IntegrationTest/SecretsManager/Controllers/ServiceAccountsControllerTests.cs index f25005b269..a482d9b04e 100644 --- a/test/Api.IntegrationTest/SecretsManager/Controllers/ServiceAccountsControllerTests.cs +++ b/test/Api.IntegrationTest/SecretsManager/Controllers/ServiceAccountsControllerTests.cs @@ -1,7 +1,7 @@ using System.Net; +using System.Net.Http.Headers; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.SecretsManager.Enums; -using Bit.Api.IntegrationTest.SecretsManager.Helpers; using Bit.Api.Models.Response; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.SecretsManager.Models.Response; @@ -24,7 +24,6 @@ public class ServiceAccountsControllerTests : IClassFixture(); _accessPolicyRepository = _factory.GetService(); _apiKeyRepository = _factory.GetService(); - _loginHelper = new LoginHelper(_factory, _client); } public async Task InitializeAsync() @@ -56,6 +54,12 @@ public class ServiceAccountsControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.NotEmpty(result.Data); + Assert.NotEmpty(result!.Data); Assert.Equal(serviceAccountIds.Count, result.Data.Count()); } @@ -95,7 +99,7 @@ public class ServiceAccountsControllerTests : IClassFixture>(); Assert.NotNull(result); - Assert.NotEmpty(result.Data); + Assert.NotEmpty(result!.Data); Assert.Equal(2, result.Data.Count()); } @@ -131,7 +135,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(serviceAccount.Id, result.Id); + Assert.Equal(serviceAccount.Id, result!.Id); Assert.Equal(serviceAccount.OrganizationId, result.OrganizationId); Assert.Equal(serviceAccount.Name, result.Name); Assert.Equal(serviceAccount.CreationDate, result.CreationDate); @@ -199,7 +203,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result.Name); + Assert.Equal(request.Name, result!.Name); AssertHelper.AssertRecent(result.RevisionDate); AssertHelper.AssertRecent(result.CreationDate); @@ -266,7 +270,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result.Name); + Assert.Equal(request.Name, result!.Name); Assert.NotEqual(initialServiceAccount.Name, result.Name); AssertHelper.AssertRecent(result.RevisionDate); Assert.NotEqual(initialServiceAccount.RevisionDate, result.RevisionDate); @@ -349,7 +353,7 @@ public class ServiceAccountsControllerTests : IClassFixture { new UserServiceAccountAccessPolicy @@ -439,7 +443,7 @@ public class ServiceAccountsControllerTests : IClassFixture { new UserServiceAccountAccessPolicy @@ -536,7 +540,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result.Name); + Assert.Equal(request.Name, result!.Name); Assert.NotNull(result.ClientSecret); Assert.Equal(mockExpiresAt, result.ExpireAt); AssertHelper.AssertRecent(result.RevisionDate); @@ -595,7 +599,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result.Name); + Assert.Equal(request.Name, result!.Name); Assert.NotNull(result.ClientSecret); Assert.Equal(mockExpiresAt, result.ExpireAt); AssertHelper.AssertRecent(result.RevisionDate); @@ -631,7 +635,7 @@ public class ServiceAccountsControllerTests : IClassFixture(); Assert.NotNull(result); - Assert.Equal(request.Name, result.Name); + Assert.Equal(request.Name, result!.Name); Assert.NotNull(result.ClientSecret); Assert.Null(result.ExpireAt); AssertHelper.AssertRecent(result.RevisionDate); @@ -695,7 +699,7 @@ public class ServiceAccountsControllerTests : IClassFixture { new UserServiceAccountAccessPolicy @@ -843,7 +847,7 @@ public class ServiceAccountsControllerTests : IClassFixture SetupServiceAccountWithAccessAsync(PermissionType permissionType) { var (org, _) = await _organizationHelper.Initialize(true, true, true); - await _loginHelper.LoginAsync(_email); + await LoginAsync(_email); var initialServiceAccount = await _serviceAccountRepository.CreateAsync(new ServiceAccount { @@ -857,7 +861,7 @@ public class ServiceAccountsControllerTests : IClassFixture { diff --git a/test/Api.IntegrationTest/SecretsManager/Enums/PermissionType.cs b/test/Api.IntegrationTest/SecretsManager/Enums/PermissionType.cs index 972bc7f0be..7f1c4d7b99 100644 --- a/test/Api.IntegrationTest/SecretsManager/Enums/PermissionType.cs +++ b/test/Api.IntegrationTest/SecretsManager/Enums/PermissionType.cs @@ -4,5 +4,4 @@ public enum PermissionType { RunAsAdmin, RunAsUserWithPermission, - RunAsServiceAccountWithPermission, } diff --git a/test/Api.IntegrationTest/SecretsManager/Helpers/LoginHelper.cs b/test/Api.IntegrationTest/SecretsManager/Helpers/LoginHelper.cs deleted file mode 100644 index 9de66bc11e..0000000000 --- a/test/Api.IntegrationTest/SecretsManager/Helpers/LoginHelper.cs +++ /dev/null @@ -1,30 +0,0 @@ -using System.Net.Http.Headers; -using Bit.Api.IntegrationTest.Factories; -using Bit.Core.SecretsManager.Models.Data; - -namespace Bit.Api.IntegrationTest.SecretsManager.Helpers; - -public class LoginHelper -{ - private readonly HttpClient _client; - private readonly ApiApplicationFactory _factory; - - public LoginHelper(ApiApplicationFactory factory, HttpClient client) - { - _factory = factory; - _client = client; - } - - public async Task LoginAsync(string email) - { - var tokens = await _factory.LoginAsync(email); - _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); - } - - public async Task LoginWithApiKeyAsync(ApiKeyClientSecretDetails apiKeyDetails) - { - var token = await _factory.LoginWithClientSecretAsync(apiKeyDetails.ApiKey.Id, apiKeyDetails.ClientSecret); - _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token); - _client.DefaultRequestHeaders.Add("service_account_id", apiKeyDetails.ApiKey.ServiceAccountId.ToString()); - } -} diff --git a/test/Api.IntegrationTest/SecretsManager/Helpers/SecretsManagerOrganizationHelper.cs b/test/Api.IntegrationTest/SecretsManager/SecretsManagerOrganizationHelper.cs similarity index 58% rename from test/Api.IntegrationTest/SecretsManager/Helpers/SecretsManagerOrganizationHelper.cs rename to test/Api.IntegrationTest/SecretsManager/SecretsManagerOrganizationHelper.cs index d2d03d979e..fea05de311 100644 --- a/test/Api.IntegrationTest/SecretsManager/Helpers/SecretsManagerOrganizationHelper.cs +++ b/test/Api.IntegrationTest/SecretsManager/SecretsManagerOrganizationHelper.cs @@ -4,12 +4,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; -using Bit.Core.SecretsManager.Commands.AccessTokens.Interfaces; -using Bit.Core.SecretsManager.Entities; -using Bit.Core.SecretsManager.Models.Data; -using Bit.Core.SecretsManager.Repositories; -namespace Bit.Api.IntegrationTest.SecretsManager.Helpers; +namespace Bit.Api.IntegrationTest.SecretsManager; public class SecretsManagerOrganizationHelper { @@ -17,20 +13,17 @@ public class SecretsManagerOrganizationHelper private readonly string _ownerEmail; private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IServiceAccountRepository _serviceAccountRepository; - private readonly ICreateAccessTokenCommand _createAccessTokenCommand; - private Organization _organization = null!; - private OrganizationUser _owner = null!; + public Organization _organization = null!; + public OrganizationUser _owner = null!; public SecretsManagerOrganizationHelper(ApiApplicationFactory factory, string ownerEmail) { _factory = factory; _organizationRepository = factory.GetService(); _organizationUserRepository = factory.GetService(); + _ownerEmail = ownerEmail; - _serviceAccountRepository = factory.GetService(); - _createAccessTokenCommand = factory.GetService(); } public async Task<(Organization organization, OrganizationUser owner)> Initialize(bool useSecrets, bool ownerAccessSecrets, bool organizationEnabled) @@ -65,7 +58,8 @@ public class SecretsManagerOrganizationHelper { var email = $"integration-test{Guid.NewGuid()}@bitwarden.com"; await _factory.LoginWithNewAccount(email); - var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, ownerEmail: email, billingEmail: email); + var (organization, owner) = + await OrganizationTestHelpers.SignUpAsync(_factory, ownerEmail: email, billingEmail: email); return organization; } @@ -77,29 +71,4 @@ public class SecretsManagerOrganizationHelper return (email, orgUser); } - - public async Task CreateNewServiceAccountApiKeyAsync() - { - var serviceAccountId = Guid.NewGuid(); - var serviceAccount = new ServiceAccount - { - Id = serviceAccountId, - OrganizationId = _organization.Id, - Name = $"integration-test-{serviceAccountId}sa", - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow - }; - await _serviceAccountRepository.CreateAsync(serviceAccount); - - var apiKey = new ApiKey - { - ServiceAccountId = serviceAccountId, - Name = "integration-token", - Key = Guid.NewGuid().ToString(), - ExpireAt = null, - Scope = "[\"api.secrets\"]", - EncryptedPayload = Guid.NewGuid().ToString() - }; - return await _createAccessTokenCommand.CreateAsync(apiKey); - } } diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs index 9d3c7ebfe5..fdbcc17e46 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -56,7 +56,7 @@ public class OrganizationsControllerTests : IDisposable private readonly IAddSecretsManagerSubscriptionCommand _addSecretsManagerSubscriptionCommand; private readonly IPushNotificationService _pushNotificationService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly ISubscriberQueries _subscriberQueries; + private readonly IGetSubscriptionQuery _getSubscriptionQuery; private readonly IReferenceEventService _referenceEventService; private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand; @@ -86,7 +86,7 @@ public class OrganizationsControllerTests : IDisposable _addSecretsManagerSubscriptionCommand = Substitute.For(); _pushNotificationService = Substitute.For(); _cancelSubscriptionCommand = Substitute.For(); - _subscriberQueries = Substitute.For(); + _getSubscriptionQuery = Substitute.For(); _referenceEventService = Substitute.For(); _organizationEnableCollectionEnhancementsCommand = Substitute.For(); @@ -113,7 +113,7 @@ public class OrganizationsControllerTests : IDisposable _addSecretsManagerSubscriptionCommand, _pushNotificationService, _cancelSubscriptionCommand, - _subscriberQueries, + _getSubscriptionQuery, _referenceEventService, _organizationEnableCollectionEnhancementsCommand); } diff --git a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs index 4af60689c3..79aa2ca13d 100644 --- a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs @@ -57,7 +57,7 @@ public class AccountsControllerTests : IDisposable private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IFeatureService _featureService; private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly ISubscriberQueries _subscriberQueries; + private readonly IGetSubscriptionQuery _getSubscriptionQuery; private readonly IReferenceEventService _referenceEventService; private readonly ICurrentContext _currentContext; @@ -90,7 +90,7 @@ public class AccountsControllerTests : IDisposable _rotateUserKeyCommand = Substitute.For(); _featureService = Substitute.For(); _cancelSubscriptionCommand = Substitute.For(); - _subscriberQueries = Substitute.For(); + _getSubscriptionQuery = Substitute.For(); _referenceEventService = Substitute.For(); _currentContext = Substitute.For(); _cipherValidator = @@ -122,7 +122,7 @@ public class AccountsControllerTests : IDisposable _rotateUserKeyCommand, _featureService, _cancelSubscriptionCommand, - _subscriberQueries, + _getSubscriptionQuery, _referenceEventService, _currentContext, _cipherValidator, diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs deleted file mode 100644 index 57480ac116..0000000000 --- a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs +++ /dev/null @@ -1,130 +0,0 @@ -using Bit.Api.Billing.Controllers; -using Bit.Api.Billing.Models; -using Bit.Core; -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Queries; -using Bit.Core.Context; -using Bit.Core.Enums; -using Bit.Core.Services; -using Bit.Core.Utilities; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.AspNetCore.Http.HttpResults; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Stripe; -using Xunit; - -namespace Bit.Api.Test.Billing.Controllers; - -[ControllerCustomize(typeof(ProviderBillingController))] -[SutProviderCustomize] -public class ProviderBillingControllerTests -{ - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_FFDisabled_NotFound( - Guid providerId, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(false); - - var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized( - Guid providerId, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(false); - - var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NoSubscriptionData_NotFound( - Guid providerId, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetSubscriptionData(providerId).ReturnsNull(); - - var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_OK( - Guid providerId, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - var configuredPlans = new List - { - new (Guid.NewGuid(), providerId, PlanType.TeamsMonthly, 50, 10, 30), - new (Guid.NewGuid(), providerId, PlanType.EnterpriseMonthly, 100, 0, 90) - }; - - var subscription = new Subscription - { - Status = "active", - CurrentPeriodEnd = new DateTime(2025, 1, 1), - Customer = new Customer { Discount = new Discount { Coupon = new Coupon { PercentOff = 10 } } } - }; - - var providerSubscriptionData = new ProviderSubscriptionData( - configuredPlans, - subscription); - - sutProvider.GetDependency().GetSubscriptionData(providerId) - .Returns(providerSubscriptionData); - - var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); - - Assert.IsType>(result); - - var providerSubscriptionDTO = ((Ok)result).Value; - - Assert.Equal(providerSubscriptionDTO.Status, subscription.Status); - Assert.Equal(providerSubscriptionDTO.CurrentPeriodEndDate, subscription.CurrentPeriodEnd); - Assert.Equal(providerSubscriptionDTO.DiscountPercentage, subscription.Customer!.Discount!.Coupon!.PercentOff); - - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var providerTeamsPlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name); - Assert.NotNull(providerTeamsPlan); - Assert.Equal(50, providerTeamsPlan.SeatMinimum); - Assert.Equal(10, providerTeamsPlan.PurchasedSeats); - Assert.Equal(30, providerTeamsPlan.AssignedSeats); - Assert.Equal(60 * teamsPlan.PasswordManager.SeatPrice, providerTeamsPlan.Cost); - Assert.Equal("Monthly", providerTeamsPlan.Cadence); - - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); - var providerEnterprisePlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name); - Assert.NotNull(providerEnterprisePlan); - Assert.Equal(100, providerEnterprisePlan.SeatMinimum); - Assert.Equal(0, providerEnterprisePlan.PurchasedSeats); - Assert.Equal(90, providerEnterprisePlan.AssignedSeats); - Assert.Equal(100 * enterprisePlan.PasswordManager.SeatPrice, providerEnterprisePlan.Cost); - Assert.Equal("Monthly", providerEnterprisePlan.Cadence); - } -} diff --git a/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs deleted file mode 100644 index 805683de27..0000000000 --- a/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs +++ /dev/null @@ -1,168 +0,0 @@ -using Bit.Api.Billing.Controllers; -using Bit.Api.Billing.Models; -using Bit.Core; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Commands; -using Bit.Core.Context; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.AspNetCore.Http.HttpResults; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Xunit; -using ProviderOrganization = Bit.Core.AdminConsole.Entities.Provider.ProviderOrganization; - -namespace Bit.Api.Test.Billing.Controllers; - -[ControllerCustomize(typeof(ProviderOrganizationController))] -[SutProviderCustomize] -public class ProviderOrganizationControllerTests -{ - [Theory, BitAutoData] - public async Task UpdateAsync_FFDisabled_NotFound( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(false); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(false); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NoProvider_NotFound( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .ReturnsNull(); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NoProviderOrganization_NotFound( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - Provider provider, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); - - sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) - .ReturnsNull(); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NoOrganization_ServerError( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - Provider provider, - ProviderOrganization providerOrganization, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); - - sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) - .Returns(providerOrganization); - - sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) - .ReturnsNull(); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NoContent( - Guid providerId, - Guid providerOrganizationId, - UpdateProviderOrganizationRequestBody requestBody, - Provider provider, - ProviderOrganization providerOrganization, - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); - - sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) - .Returns(providerOrganization); - - sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) - .Returns(organization); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - await sutProvider.GetDependency().Received(1) - .AssignSeatsToClientOrganization( - provider, - organization, - requestBody.AssignedSeats); - - Assert.IsType(result); - } -} diff --git a/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs b/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs deleted file mode 100644 index 918b7c47a2..0000000000 --- a/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs +++ /dev/null @@ -1,339 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Billing; -using Bit.Core.Billing.Commands.Implementations; -using Bit.Core.Billing.Entities; -using Bit.Core.Billing.Queries; -using Bit.Core.Billing.Repositories; -using Bit.Core.Enums; -using Bit.Core.Models.StaticStore; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Utilities; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using Xunit; - -using static Bit.Core.Test.Billing.Utilities; - -namespace Bit.Core.Test.Billing.Commands; - -[SutProviderCustomize] -public class AssignSeatsToClientOrganizationCommandTests -{ - [Theory, BitAutoData] - public Task AssignSeatsToClientOrganization_NullProvider_ArgumentNullException( - Organization organization, - int seats, - SutProvider sutProvider) - => Assert.ThrowsAsync(() => - sutProvider.Sut.AssignSeatsToClientOrganization(null, organization, seats)); - - [Theory, BitAutoData] - public Task AssignSeatsToClientOrganization_NullOrganization_ArgumentNullException( - Provider provider, - int seats, - SutProvider sutProvider) - => Assert.ThrowsAsync(() => - sutProvider.Sut.AssignSeatsToClientOrganization(provider, null, seats)); - - [Theory, BitAutoData] - public Task AssignSeatsToClientOrganization_NegativeSeats_BillingException( - Provider provider, - Organization organization, - SutProvider sutProvider) - => Assert.ThrowsAsync(() => - sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, -5)); - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_CurrentSeatsMatchesNewSeats_NoOp( - Provider provider, - Organization organization, - int seats, - SutProvider sutProvider) - { - organization.PlanType = PlanType.TeamsMonthly; - - organization.Seats = seats; - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - await sutProvider.GetDependency().DidNotReceive().GetByProviderId(provider.Id); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_OrganizationPlanTypeDoesNotSupportConsolidatedBilling_ContactSupport( - Provider provider, - Organization organization, - int seats, - SutProvider sutProvider) - { - organization.PlanType = PlanType.FamiliesAnnually; - - await ThrowsContactSupportAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_ProviderPlanIsNotConfigured_ContactSupport( - Provider provider, - Organization organization, - int seats, - SutProvider sutProvider) - { - organization.PlanType = PlanType.TeamsMonthly; - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(new List - { - new () - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id - } - }); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_BelowToBelow_Succeeds( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.Seats = 10; - - organization.PlanType = PlanType.TeamsMonthly; - - // Scale up 10 seats - const int seats = 20; - - var providerPlans = new List - { - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - // 100 minimum - SeatMinimum = 100, - AllocatedSeats = 50 - }, - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.EnterpriseMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - SeatMinimum = 500, - AllocatedSeats = 0 - } - }; - - var providerPlan = providerPlans.First(); - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); - - // 50 seats currently assigned with a seat minimum of 100 - sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(50); - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().AdjustSeats( - Arg.Any(), - Arg.Any(), - Arg.Any(), - Arg.Any()); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.Seats == seats)); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - pPlan => pPlan.AllocatedSeats == 60)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_BelowToAbove_Succeeds( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.Seats = 10; - - organization.PlanType = PlanType.TeamsMonthly; - - // Scale up 10 seats - const int seats = 20; - - var providerPlans = new List - { - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - // 100 minimum - SeatMinimum = 100, - AllocatedSeats = 95 - }, - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.EnterpriseMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - SeatMinimum = 500, - AllocatedSeats = 0 - } - }; - - var providerPlan = providerPlans.First(); - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); - - // 95 seats currently assigned with a seat minimum of 100 - sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(95); - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - // 95 current + 10 seat scale = 105 seats, 5 above the minimum - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - providerPlan.SeatMinimum!.Value, - 105); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.Seats == seats)); - - // 105 total seats - 100 minimum = 5 purchased seats - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 5 && pPlan.AllocatedSeats == 105)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_AboveToAbove_Succeeds( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.Seats = 10; - - organization.PlanType = PlanType.TeamsMonthly; - - // Scale up 10 seats - const int seats = 20; - - var providerPlans = new List - { - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id, - // 10 additional purchased seats - PurchasedSeats = 10, - // 100 seat minimum - SeatMinimum = 100, - AllocatedSeats = 110 - }, - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.EnterpriseMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - SeatMinimum = 500, - AllocatedSeats = 0 - } - }; - - var providerPlan = providerPlans.First(); - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); - - // 110 seats currently assigned with a seat minimum of 100 - sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(110); - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - // 110 current + 10 seat scale up = 120 seats - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - 110, - 120); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.Seats == seats)); - - // 120 total seats - 100 seat minimum = 20 purchased seats - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 20 && pPlan.AllocatedSeats == 120)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_AboveToBelow_Succeeds( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.Seats = 50; - - organization.PlanType = PlanType.TeamsMonthly; - - // Scale down 30 seats - const int seats = 20; - - var providerPlans = new List - { - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id, - // 10 additional purchased seats - PurchasedSeats = 10, - // 100 seat minimum - SeatMinimum = 100, - AllocatedSeats = 110 - }, - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.EnterpriseMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - SeatMinimum = 500, - AllocatedSeats = 0 - } - }; - - var providerPlan = providerPlans.First(); - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); - - // 110 seats currently assigned with a seat minimum of 100 - sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(110); - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - 110, - providerPlan.SeatMinimum!.Value); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.Seats == seats)); - - // Being below the seat minimum means no purchased seats. - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 0 && pPlan.AllocatedSeats == 80)); - } -} diff --git a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs index 968bfeb84d..5de14f006f 100644 --- a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs +++ b/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs @@ -1,13 +1,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Commands.Implementations; using Bit.Core.Enums; +using Bit.Core.Exceptions; using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using NSubstitute.ReturnsExtensions; using Xunit; -using static Bit.Core.Test.Billing.Utilities; using BT = Braintree; using S = Stripe; @@ -355,4 +355,13 @@ public class RemovePaymentMethodCommandTests return (braintreeGateway, customerGateway, paymentMethodGateway); } + + private static async Task ThrowsContactSupportAsync(Func function) + { + const string message = "Could not remove your payment method. Please contact support for assistance."; + + var exception = await Assert.ThrowsAsync(function); + + Assert.Equal(message, exception.Message); + } } diff --git a/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs b/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs new file mode 100644 index 0000000000..adae46a791 --- /dev/null +++ b/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs @@ -0,0 +1,104 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Queries.Implementations; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Queries; + +[SutProviderCustomize] +public class GetSubscriptionQueryTests +{ + [Theory, BitAutoData] + public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetSubscription(null)); + + [Theory, BitAutoData] + public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ThrowsGatewayException( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(organization)); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Organization_NoSubscription_ThrowsGatewayException( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(organization)); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Organization_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Equivalent(subscription, gotSubscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_User_NoGatewaySubscriptionId_ThrowsGatewayException( + User user, + SutProvider sutProvider) + { + user.GatewaySubscriptionId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(user)); + } + + [Theory, BitAutoData] + public async Task GetSubscription_User_NoSubscription_ThrowsGatewayException( + User user, + SutProvider sutProvider) + { + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(user)); + } + + [Theory, BitAutoData] + public async Task GetSubscription_User_Succeeds( + User user, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscription(user); + + Assert.Equivalent(subscription, gotSubscription); + } + + private static async Task ThrowsContactSupportAsync(Func function) + { + const string message = "Something went wrong with your request. Please contact support."; + + var exception = await Assert.ThrowsAsync(function); + + Assert.Equal(message, exception.Message); + } +} diff --git a/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs deleted file mode 100644 index 534444ba94..0000000000 --- a/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs +++ /dev/null @@ -1,154 +0,0 @@ -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Entities; -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Queries; -using Bit.Core.Billing.Queries.Implementations; -using Bit.Core.Billing.Repositories; -using Bit.Core.Enums; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Stripe; -using Xunit; - -namespace Bit.Core.Test.Billing.Queries; - -[SutProviderCustomize] -public class ProviderBillingQueriesTests -{ - #region GetSubscriptionData - - [Theory, BitAutoData] - public async Task GetSubscriptionData_NullProvider_ReturnsNull( - SutProvider sutProvider, - Guid providerId) - { - var providerRepository = sutProvider.GetDependency(); - - providerRepository.GetByIdAsync(providerId).ReturnsNull(); - - var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); - - Assert.Null(subscriptionData); - - await providerRepository.Received(1).GetByIdAsync(providerId); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionData_NullSubscription_ReturnsNull( - SutProvider sutProvider, - Guid providerId, - Provider provider) - { - var providerRepository = sutProvider.GetDependency(); - - providerRepository.GetByIdAsync(providerId).Returns(provider); - - var subscriberQueries = sutProvider.GetDependency(); - - subscriberQueries.GetSubscription(provider).ReturnsNull(); - - var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); - - Assert.Null(subscriptionData); - - await providerRepository.Received(1).GetByIdAsync(providerId); - - await subscriberQueries.Received(1).GetSubscription( - provider, - Arg.Is( - options => options.Expand.Count == 1 && options.Expand.First() == "customer")); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionData_Success( - SutProvider sutProvider, - Guid providerId, - Provider provider) - { - var providerRepository = sutProvider.GetDependency(); - - providerRepository.GetByIdAsync(providerId).Returns(provider); - - var subscriberQueries = sutProvider.GetDependency(); - - var subscription = new Subscription(); - - subscriberQueries.GetSubscription(provider, Arg.Is( - options => options.Expand.Count == 1 && options.Expand.First() == "customer")).Returns(subscription); - - var providerPlanRepository = sutProvider.GetDependency(); - - var enterprisePlan = new ProviderPlan - { - Id = Guid.NewGuid(), - ProviderId = providerId, - PlanType = PlanType.EnterpriseMonthly, - SeatMinimum = 100, - PurchasedSeats = 0, - AllocatedSeats = 0 - }; - - var teamsPlan = new ProviderPlan - { - Id = Guid.NewGuid(), - ProviderId = providerId, - PlanType = PlanType.TeamsMonthly, - SeatMinimum = 50, - PurchasedSeats = 10, - AllocatedSeats = 60 - }; - - var providerPlans = new List - { - enterprisePlan, - teamsPlan, - }; - - providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans); - - var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId); - - Assert.NotNull(subscriptionData); - - Assert.Equivalent(subscriptionData.Subscription, subscription); - - Assert.Equal(2, subscriptionData.ProviderPlans.Count); - - var configuredEnterprisePlan = - subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => - configuredPlan.PlanType == PlanType.EnterpriseMonthly); - - var configuredTeamsPlan = - subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => - configuredPlan.PlanType == PlanType.TeamsMonthly); - - Compare(enterprisePlan, configuredEnterprisePlan); - - Compare(teamsPlan, configuredTeamsPlan); - - await providerRepository.Received(1).GetByIdAsync(providerId); - - await subscriberQueries.Received(1).GetSubscription( - provider, - Arg.Is( - options => options.Expand.Count == 1 && options.Expand.First() == "customer")); - - await providerPlanRepository.Received(1).GetByProviderId(providerId); - - return; - - void Compare(ProviderPlan providerPlan, ConfiguredProviderPlan configuredProviderPlan) - { - Assert.NotNull(configuredProviderPlan); - Assert.Equal(providerPlan.Id, configuredProviderPlan.Id); - Assert.Equal(providerPlan.ProviderId, configuredProviderPlan.ProviderId); - Assert.Equal(providerPlan.SeatMinimum!.Value, configuredProviderPlan.SeatMinimum); - Assert.Equal(providerPlan.PurchasedSeats!.Value, configuredProviderPlan.PurchasedSeats); - Assert.Equal(providerPlan.AllocatedSeats!.Value, configuredProviderPlan.AssignedSeats); - } - } - #endregion -} diff --git a/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs deleted file mode 100644 index 51682a6661..0000000000 --- a/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs +++ /dev/null @@ -1,263 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Billing.Queries.Implementations; -using Bit.Core.Entities; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Stripe; -using Xunit; - -using static Bit.Core.Test.Billing.Utilities; - -namespace Bit.Core.Test.Billing.Queries; - -[SutProviderCustomize] -public class SubscriberQueriesTests -{ - #region GetSubscription - [Theory, BitAutoData] - public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) - => await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetSubscription(null)); - - [Theory, BitAutoData] - public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - organization.GatewaySubscriptionId = null; - - var gotSubscription = await sutProvider.Sut.GetSubscription(organization); - - Assert.Null(gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Organization_NoSubscription_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .ReturnsNull(); - - var gotSubscription = await sutProvider.Sut.GetSubscription(organization); - - Assert.Null(gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Organization_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscription(organization); - - Assert.Equivalent(subscription, gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_User_NoGatewaySubscriptionId_ReturnsNull( - User user, - SutProvider sutProvider) - { - user.GatewaySubscriptionId = null; - - var gotSubscription = await sutProvider.Sut.GetSubscription(user); - - Assert.Null(gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_User_NoSubscription_ReturnsNull( - User user, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .ReturnsNull(); - - var gotSubscription = await sutProvider.Sut.GetSubscription(user); - - Assert.Null(gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_User_Succeeds( - User user, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscription(user); - - Assert.Equivalent(subscription, gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Provider_NoGatewaySubscriptionId_ReturnsNull( - Provider provider, - SutProvider sutProvider) - { - provider.GatewaySubscriptionId = null; - - var gotSubscription = await sutProvider.Sut.GetSubscription(provider); - - Assert.Null(gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Provider_NoSubscription_ReturnsNull( - Provider provider, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) - .ReturnsNull(); - - var gotSubscription = await sutProvider.Sut.GetSubscription(provider); - - Assert.Null(gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Provider_Succeeds( - Provider provider, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscription(provider); - - Assert.Equivalent(subscription, gotSubscription); - } - #endregion - - #region GetSubscriptionOrThrow - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) - => await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetSubscriptionOrThrow(null)); - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Organization_NoGatewaySubscriptionId_ThrowsGatewayException( - Organization organization, - SutProvider sutProvider) - { - organization.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Organization_NoSubscription_ThrowsGatewayException( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Organization_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization); - - Assert.Equivalent(subscription, gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_User_NoGatewaySubscriptionId_ThrowsGatewayException( - User user, - SutProvider sutProvider) - { - user.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_User_NoSubscription_ThrowsGatewayException( - User user, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_User_Succeeds( - User user, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(user); - - Assert.Equivalent(subscription, gotSubscription); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Provider_NoGatewaySubscriptionId_ThrowsGatewayException( - Provider provider, - SutProvider sutProvider) - { - provider.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Provider_NoSubscription_ThrowsGatewayException( - Provider provider, - SutProvider sutProvider) - { - sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Provider_Succeeds( - Provider provider, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(provider); - - Assert.Equivalent(subscription, gotSubscription); - } - #endregion -} diff --git a/test/Core.Test/Billing/Utilities.cs b/test/Core.Test/Billing/Utilities.cs index ea9e6c694c..359c010a29 100644 --- a/test/Core.Test/Billing/Utilities.cs +++ b/test/Core.Test/Billing/Utilities.cs @@ -1,4 +1,4 @@ -using Bit.Core.Billing; +using Bit.Core.Exceptions; using Xunit; using static Bit.Core.Billing.Utilities; @@ -11,7 +11,7 @@ public static class Utilities { var contactSupport = ContactSupport(); - var exception = await Assert.ThrowsAsync(function); + var exception = await Assert.ThrowsAsync(function); Assert.Equal(contactSupport.Message, exception.Message); } diff --git a/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs b/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs index 2dc23056d1..def8f5c14c 100644 --- a/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs +++ b/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs @@ -42,23 +42,4 @@ public class IdentityApplicationFactory : WebApplicationFactoryBase return (root.GetProperty("access_token").GetString(), root.GetProperty("refresh_token").GetString()); } - - public async Task TokenFromAccessTokenAsync(Guid clientId, string clientSecret, - DeviceType deviceType = DeviceType.SDK) - { - var context = await Server.PostAsync("/connect/token", - new FormUrlEncodedContent(new Dictionary - { - { "scope", "api.secrets" }, - { "client_id", clientId.ToString() }, - { "client_secret", clientSecret }, - { "grant_type", "client_credentials" }, - { "deviceType", ((int)deviceType).ToString() } - })); - - using var body = await AssertHelper.AssertResponseTypeIs(context); - var root = body.RootElement; - - return root.GetProperty("access_token").GetString(); - } } diff --git a/util/Migrator/DbMigrator.cs b/util/Migrator/DbMigrator.cs index 11b80fac78..a6ca53abdb 100644 --- a/util/Migrator/DbMigrator.cs +++ b/util/Migrator/DbMigrator.cs @@ -13,14 +13,11 @@ public class DbMigrator { private readonly string _connectionString; private readonly ILogger _logger; - private readonly bool _skipDatabasePreparation; - public DbMigrator(string connectionString, ILogger logger = null, - bool skipDatabasePreparation = false) + public DbMigrator(string connectionString, ILogger logger = null) { _connectionString = connectionString; _logger = logger ?? CreateLogger(); - _skipDatabasePreparation = skipDatabasePreparation; } public bool MigrateMsSqlDatabaseWithRetries(bool enableLogging = true, @@ -34,10 +31,7 @@ public class DbMigrator { try { - if (!_skipDatabasePreparation) - { - PrepareDatabase(cancellationToken); - } + PrepareDatabase(cancellationToken); var success = MigrateDatabase(enableLogging, repeatable, folderName, dryRun, cancellationToken); return success; diff --git a/util/Migrator/SqlServerDbMigrator.cs b/util/Migrator/SqlServerDbMigrator.cs index d76b26cfb7..b443260820 100644 --- a/util/Migrator/SqlServerDbMigrator.cs +++ b/util/Migrator/SqlServerDbMigrator.cs @@ -10,8 +10,7 @@ public class SqlServerDbMigrator : IDbMigrator public SqlServerDbMigrator(GlobalSettings globalSettings, ILogger logger) { - _migrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, logger, - globalSettings.SqlServer.SkipDatabasePreparation); + _migrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, logger); } public bool MigrateDatabase(bool enableLogging = true,