From 5237b522e5ba83a2bd804e69aabbb10ccf5c60bb Mon Sep 17 00:00:00 2001
From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com>
Date: Mon, 25 Mar 2024 12:47:15 -0400
Subject: [PATCH 01/21] [deps] Billing: Update Stripe.net to v43.20.0 (#3867)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: Conner Turnbull <133619638+cturnbull-bitwarden@users.noreply.github.com>
---
src/Core/Core.csproj | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj
index ff3c632b5b..fa8bee3cab 100644
--- a/src/Core/Core.csproj
+++ b/src/Core/Core.csproj
@@ -53,7 +53,7 @@
-
+
From 4c1d24b10a2d272f209a0c82f2db86de9e6a051b Mon Sep 17 00:00:00 2001
From: Thomas Rittson <31796059+eliykat@users.noreply.github.com>
Date: Tue, 26 Mar 2024 08:34:55 +1000
Subject: [PATCH 02/21] Use static property for JsonSerializerOptions (#3923)
---
src/Core/Utilities/CoreHelpers.cs | 18 ++++++------------
1 file changed, 6 insertions(+), 12 deletions(-)
diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs
index 5d0becf7b4..af658a409a 100644
--- a/src/Core/Utilities/CoreHelpers.cs
+++ b/src/Core/Utilities/CoreHelpers.cs
@@ -32,6 +32,10 @@ public static class CoreHelpers
private static readonly Random _random = new Random();
private static readonly string RealConnectingIp = "X-Connecting-IP";
private static readonly Regex _whiteSpaceRegex = new Regex(@"\s+");
+ private static readonly JsonSerializerOptions _jsonSerializerOptions = new()
+ {
+ PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
+ };
///
/// Generate sequential Guid for Sql Server.
@@ -778,22 +782,12 @@ public static class CoreHelpers
return new T();
}
- var options = new JsonSerializerOptions
- {
- PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
- };
-
- return System.Text.Json.JsonSerializer.Deserialize(jsonData, options);
+ return System.Text.Json.JsonSerializer.Deserialize(jsonData, _jsonSerializerOptions);
}
public static string ClassToJsonData(T data)
{
- var options = new JsonSerializerOptions
- {
- PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
- };
-
- return System.Text.Json.JsonSerializer.Serialize(data, options);
+ return System.Text.Json.JsonSerializer.Serialize(data, _jsonSerializerOptions);
}
public static ICollection AddIfNotExists(this ICollection list, T item)
From 5355b2b969c4f5e1c9d1d923b185cbda4a561501 Mon Sep 17 00:00:00 2001
From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com>
Date: Tue, 26 Mar 2024 09:50:47 +0100
Subject: [PATCH 03/21] [deps] Tools: Update aws-sdk-net monorepo to
v3.7.300.61 (#3925)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
---
src/Core/Core.csproj | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj
index fa8bee3cab..239e999e79 100644
--- a/src/Core/Core.csproj
+++ b/src/Core/Core.csproj
@@ -21,8 +21,8 @@
-
-
+
+
From 2790687dc2468a126bc53f9613e2bb1e63a5a097 Mon Sep 17 00:00:00 2001
From: Matt Bishop
Date: Wed, 27 Mar 2024 11:20:54 -0400
Subject: [PATCH 04/21] [PM-6938] Allow certain database operations to be
skipped (#3914)
* Centralize database migration logic
* Clean up unused usings
* Prizatize
* Remove verbose flag from Docker invocation
* Allow certain database operations to be skipped
* Readonly
---
src/Admin/Jobs/JobsHostedService.cs | 8 ++++++--
src/Core/Settings/GlobalSettings.cs | 2 ++
util/Migrator/DbMigrator.cs | 10 ++++++++--
util/Migrator/SqlServerDbMigrator.cs | 3 ++-
4 files changed, 18 insertions(+), 5 deletions(-)
diff --git a/src/Admin/Jobs/JobsHostedService.cs b/src/Admin/Jobs/JobsHostedService.cs
index adba27970c..89cf5512c3 100644
--- a/src/Admin/Jobs/JobsHostedService.cs
+++ b/src/Admin/Jobs/JobsHostedService.cs
@@ -76,14 +76,18 @@ public class JobsHostedService : BaseJobsHostedService
{
new Tuple(typeof(DeleteSendsJob), everyFiveMinutesTrigger),
new Tuple(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger),
- new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger),
- new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger),
new Tuple(typeof(DeleteCiphersJob), everyDayAtMidnightUtc),
new Tuple(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger),
new Tuple(typeof(DeleteAuthRequestsJob), everyFifteenMinutesTrigger),
new Tuple(typeof(DeleteUnverifiedOrganizationDomainsJob), everyDayAtTwoAmUtcTrigger),
};
+ if (!(_globalSettings.SqlServer?.DisableDatabaseMaintenanceJobs ?? false))
+ {
+ jobs.Add(new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger));
+ jobs.Add(new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger));
+ }
+
if (!_globalSettings.SelfHosted)
{
jobs.Add(new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger));
diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs
index 84037a0a1c..50b4efe6fb 100644
--- a/src/Core/Settings/GlobalSettings.cs
+++ b/src/Core/Settings/GlobalSettings.cs
@@ -221,6 +221,8 @@ public class GlobalSettings : IGlobalSettings
private string _connectionString;
private string _readOnlyConnectionString;
private string _jobSchedulerConnectionString;
+ public bool SkipDatabasePreparation { get; set; }
+ public bool DisableDatabaseMaintenanceJobs { get; set; }
public string ConnectionString
{
diff --git a/util/Migrator/DbMigrator.cs b/util/Migrator/DbMigrator.cs
index a6ca53abdb..11b80fac78 100644
--- a/util/Migrator/DbMigrator.cs
+++ b/util/Migrator/DbMigrator.cs
@@ -13,11 +13,14 @@ public class DbMigrator
{
private readonly string _connectionString;
private readonly ILogger _logger;
+ private readonly bool _skipDatabasePreparation;
- public DbMigrator(string connectionString, ILogger logger = null)
+ public DbMigrator(string connectionString, ILogger logger = null,
+ bool skipDatabasePreparation = false)
{
_connectionString = connectionString;
_logger = logger ?? CreateLogger();
+ _skipDatabasePreparation = skipDatabasePreparation;
}
public bool MigrateMsSqlDatabaseWithRetries(bool enableLogging = true,
@@ -31,7 +34,10 @@ public class DbMigrator
{
try
{
- PrepareDatabase(cancellationToken);
+ if (!_skipDatabasePreparation)
+ {
+ PrepareDatabase(cancellationToken);
+ }
var success = MigrateDatabase(enableLogging, repeatable, folderName, dryRun, cancellationToken);
return success;
diff --git a/util/Migrator/SqlServerDbMigrator.cs b/util/Migrator/SqlServerDbMigrator.cs
index b443260820..d76b26cfb7 100644
--- a/util/Migrator/SqlServerDbMigrator.cs
+++ b/util/Migrator/SqlServerDbMigrator.cs
@@ -10,7 +10,8 @@ public class SqlServerDbMigrator : IDbMigrator
public SqlServerDbMigrator(GlobalSettings globalSettings, ILogger logger)
{
- _migrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, logger);
+ _migrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, logger,
+ globalSettings.SqlServer.SkipDatabasePreparation);
}
public bool MigrateDatabase(bool enableLogging = true,
From a390fcafaf004312772cc804d5c1efd0ac19c8f6 Mon Sep 17 00:00:00 2001
From: Matt Bishop
Date: Wed, 27 Mar 2024 12:35:24 -0400
Subject: [PATCH 05/21] Adjust scan permissions (#3931)
---
.github/workflows/scan.yml | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/scan.yml b/.github/workflows/scan.yml
index 62203804b9..89d75ccf0f 100644
--- a/.github/workflows/scan.yml
+++ b/.github/workflows/scan.yml
@@ -10,8 +10,6 @@ on:
pull_request_target:
types: [opened, synchronize]
-permissions: read-all
-
jobs:
check-run:
name: Check PR run
@@ -22,6 +20,8 @@ jobs:
runs-on: ubuntu-22.04
needs: check-run
permissions:
+ contents: read
+ pull-requests: write
security-events: write
steps:
@@ -43,7 +43,7 @@ jobs:
additional_params: --report-format sarif --output-path . ${{ env.INCREMENTAL }}
- name: Upload Checkmarx results to GitHub
- uses: github/codeql-action/upload-sarif@8a470fddafa5cbb6266ee11b37ef4d8aae19c571 # v3.24.6
+ uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9
with:
sarif_file: cx_result.sarif
@@ -51,6 +51,9 @@ jobs:
name: Quality scan
runs-on: ubuntu-22.04
needs: check-run
+ permissions:
+ contents: read
+ pull-requests: write
steps:
- name: Check out repo
From 728d49ab5dd34aaef9527e416ea74d56c749f446 Mon Sep 17 00:00:00 2001
From: Thomas Rittson <31796059+eliykat@users.noreply.github.com>
Date: Thu, 28 Mar 2024 08:08:35 +1000
Subject: [PATCH 06/21] [AC-1724] Remove BulkCollectionAccess feature flag
(#3928)
---
src/Api/Controllers/CollectionsController.cs | 3 ---
src/Core/Constants.cs | 1 -
2 files changed, 4 deletions(-)
diff --git a/src/Api/Controllers/CollectionsController.cs b/src/Api/Controllers/CollectionsController.cs
index 3eeae17a50..7711e44220 100644
--- a/src/Api/Controllers/CollectionsController.cs
+++ b/src/Api/Controllers/CollectionsController.cs
@@ -2,7 +2,6 @@
using Bit.Api.Models.Response;
using Bit.Api.Utilities;
using Bit.Api.Vault.AuthorizationHandlers.Collections;
-using Bit.Core;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
@@ -11,7 +10,6 @@ using Bit.Core.Models.Data;
using Bit.Core.OrganizationFeatures.OrganizationCollections.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
-using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
@@ -322,7 +320,6 @@ public class CollectionsController : Controller
}
[HttpPost("bulk-access")]
- [RequireFeature(FeatureFlagKeys.BulkCollectionAccess)]
public async Task PostBulkCollectionAccess(Guid orgId, [FromBody] BulkCollectionAccessRequestModel model)
{
// Authorization logic assumes flexible collections is enabled
diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs
index 457b47d458..598a5c062b 100644
--- a/src/Core/Constants.cs
+++ b/src/Core/Constants.cs
@@ -114,7 +114,6 @@ public static class FeatureFlagKeys
///
public const string FlexibleCollections = "flexible-collections-disabled-do-not-use";
public const string FlexibleCollectionsV1 = "flexible-collections-v-1"; // v-1 is intentional
- public const string BulkCollectionAccess = "bulk-collection-access";
public const string ItemShare = "item-share";
public const string KeyRotationImprovements = "key-rotation-improvements";
public const string DuoRedirect = "duo-redirect";
From 46dba1519432047e7e58e259083e575e5bad208e Mon Sep 17 00:00:00 2001
From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com>
Date: Thu, 28 Mar 2024 10:04:31 +0100
Subject: [PATCH 07/21] [deps] Tools: Update aws-sdk-net monorepo to
v3.7.300.63 (#3933)
Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
---
src/Core/Core.csproj | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj
index 239e999e79..92c2198242 100644
--- a/src/Core/Core.csproj
+++ b/src/Core/Core.csproj
@@ -21,8 +21,8 @@
-
-
+
+
From ffd988eeda34a8a6a78ee31acbbee75a6ea70947 Mon Sep 17 00:00:00 2001
From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com>
Date: Thu, 28 Mar 2024 08:46:12 -0400
Subject: [PATCH 08/21] [AC-1904] Implement endpoint to retrieve Provider
subscription (#3921)
* Refactor Core.Billing prior to adding new logic
* Add ProviderBillingQueries.GetSubscriptionData
* Add ProviderBillingController.GetSubscriptionAsync
---
.../Controllers/OrganizationsController.cs | 8 +-
.../Auth/Controllers/AccountsController.cs | 8 +-
.../Controllers/ProviderBillingController.cs | 44 +++
.../Billing/Models/ProviderSubscriptionDTO.cs | 47 ++++
.../Entities/Provider/Provider.cs | 22 +-
src/Core/Billing/BillingException.cs | 9 +
.../Commands/ICancelSubscriptionCommand.cs | 2 -
.../Commands/IRemovePaymentMethodCommand.cs | 7 +
.../RemovePaymentMethodCommand.cs | 64 ++---
src/Core/Billing/Entities/ProviderPlan.cs | 3 +-
.../Extensions/ServiceCollectionExtensions.cs | 3 +-
.../Billing/Models/ConfiguredProviderPlan.cs | 22 ++
.../Models/ProviderSubscriptionData.cs | 7 +
.../Billing/Queries/IGetSubscriptionQuery.cs | 18 --
.../Queries/IProviderBillingQueries.cs | 14 +
.../Billing/Queries/ISubscriberQueries.cs | 30 ++
.../Implementations/GetSubscriptionQuery.cs | 36 ---
.../Implementations/ProviderBillingQueries.cs | 49 ++++
.../Implementations/SubscriberQueries.cs | 61 ++++
.../Repositories/IProviderPlanRepository.cs | 2 +-
src/Core/Billing/Utilities.cs | 11 +-
src/Core/Constants.cs | 1 +
.../Repositories/ProviderPlanRepository.cs | 4 +-
.../Repositories/ProviderPlanRepository.cs | 7 +-
.../OrganizationsControllerTests.cs | 6 +-
.../Controllers/AccountsControllerTests.cs | 6 +-
.../RemovePaymentMethodCommandTests.cs | 11 +-
.../Queries/GetSubscriptionQueryTests.cs | 104 -------
.../Queries/ProviderBillingQueriesTests.cs | 151 ++++++++++
.../Billing/Queries/SubscriberQueriesTests.cs | 263 ++++++++++++++++++
test/Core.Test/Billing/Utilities.cs | 4 +-
31 files changed, 786 insertions(+), 238 deletions(-)
create mode 100644 src/Api/Billing/Controllers/ProviderBillingController.cs
create mode 100644 src/Api/Billing/Models/ProviderSubscriptionDTO.cs
create mode 100644 src/Core/Billing/BillingException.cs
create mode 100644 src/Core/Billing/Models/ConfiguredProviderPlan.cs
create mode 100644 src/Core/Billing/Models/ProviderSubscriptionData.cs
delete mode 100644 src/Core/Billing/Queries/IGetSubscriptionQuery.cs
create mode 100644 src/Core/Billing/Queries/IProviderBillingQueries.cs
create mode 100644 src/Core/Billing/Queries/ISubscriberQueries.cs
delete mode 100644 src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs
create mode 100644 src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs
create mode 100644 src/Core/Billing/Queries/Implementations/SubscriberQueries.cs
delete mode 100644 test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs
create mode 100644 test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs
create mode 100644 test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs
diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs
index 2a4ba3a1db..822f9635eb 100644
--- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs
+++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs
@@ -66,7 +66,7 @@ public class OrganizationsController : Controller
private readonly IAddSecretsManagerSubscriptionCommand _addSecretsManagerSubscriptionCommand;
private readonly IPushNotificationService _pushNotificationService;
private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand;
- private readonly IGetSubscriptionQuery _getSubscriptionQuery;
+ private readonly ISubscriberQueries _subscriberQueries;
private readonly IReferenceEventService _referenceEventService;
private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand;
@@ -93,7 +93,7 @@ public class OrganizationsController : Controller
IAddSecretsManagerSubscriptionCommand addSecretsManagerSubscriptionCommand,
IPushNotificationService pushNotificationService,
ICancelSubscriptionCommand cancelSubscriptionCommand,
- IGetSubscriptionQuery getSubscriptionQuery,
+ ISubscriberQueries subscriberQueries,
IReferenceEventService referenceEventService,
IOrganizationEnableCollectionEnhancementsCommand organizationEnableCollectionEnhancementsCommand)
{
@@ -119,7 +119,7 @@ public class OrganizationsController : Controller
_addSecretsManagerSubscriptionCommand = addSecretsManagerSubscriptionCommand;
_pushNotificationService = pushNotificationService;
_cancelSubscriptionCommand = cancelSubscriptionCommand;
- _getSubscriptionQuery = getSubscriptionQuery;
+ _subscriberQueries = subscriberQueries;
_referenceEventService = referenceEventService;
_organizationEnableCollectionEnhancementsCommand = organizationEnableCollectionEnhancementsCommand;
}
@@ -479,7 +479,7 @@ public class OrganizationsController : Controller
throw new NotFoundException();
}
- var subscription = await _getSubscriptionQuery.GetSubscription(organization);
+ var subscription = await _subscriberQueries.GetSubscriptionOrThrow(organization);
await _cancelSubscriptionCommand.CancelSubscription(subscription,
new OffboardingSurveyResponse
diff --git a/src/Api/Auth/Controllers/AccountsController.cs b/src/Api/Auth/Controllers/AccountsController.cs
index 29ede684be..5f1910fb28 100644
--- a/src/Api/Auth/Controllers/AccountsController.cs
+++ b/src/Api/Auth/Controllers/AccountsController.cs
@@ -69,7 +69,7 @@ public class AccountsController : Controller
private readonly IRotateUserKeyCommand _rotateUserKeyCommand;
private readonly IFeatureService _featureService;
private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand;
- private readonly IGetSubscriptionQuery _getSubscriptionQuery;
+ private readonly ISubscriberQueries _subscriberQueries;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
@@ -104,7 +104,7 @@ public class AccountsController : Controller
IRotateUserKeyCommand rotateUserKeyCommand,
IFeatureService featureService,
ICancelSubscriptionCommand cancelSubscriptionCommand,
- IGetSubscriptionQuery getSubscriptionQuery,
+ ISubscriberQueries subscriberQueries,
IReferenceEventService referenceEventService,
ICurrentContext currentContext,
IRotationValidator, IEnumerable> cipherValidator,
@@ -133,7 +133,7 @@ public class AccountsController : Controller
_rotateUserKeyCommand = rotateUserKeyCommand;
_featureService = featureService;
_cancelSubscriptionCommand = cancelSubscriptionCommand;
- _getSubscriptionQuery = getSubscriptionQuery;
+ _subscriberQueries = subscriberQueries;
_referenceEventService = referenceEventService;
_currentContext = currentContext;
_cipherValidator = cipherValidator;
@@ -831,7 +831,7 @@ public class AccountsController : Controller
throw new UnauthorizedAccessException();
}
- var subscription = await _getSubscriptionQuery.GetSubscription(user);
+ var subscription = await _subscriberQueries.GetSubscriptionOrThrow(user);
await _cancelSubscriptionCommand.CancelSubscription(subscription,
new OffboardingSurveyResponse
diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs
new file mode 100644
index 0000000000..583a5937e4
--- /dev/null
+++ b/src/Api/Billing/Controllers/ProviderBillingController.cs
@@ -0,0 +1,44 @@
+using Bit.Api.Billing.Models;
+using Bit.Core;
+using Bit.Core.Billing.Queries;
+using Bit.Core.Context;
+using Bit.Core.Services;
+using Microsoft.AspNetCore.Authorization;
+using Microsoft.AspNetCore.Mvc;
+
+namespace Bit.Api.Billing.Controllers;
+
+[Route("providers/{providerId:guid}/billing")]
+[Authorize("Application")]
+public class ProviderBillingController(
+ ICurrentContext currentContext,
+ IFeatureService featureService,
+ IProviderBillingQueries providerBillingQueries) : Controller
+{
+ [HttpGet("subscription")]
+ public async Task GetSubscriptionAsync([FromRoute] Guid providerId)
+ {
+ if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
+ {
+ return TypedResults.NotFound();
+ }
+
+ if (!currentContext.ProviderProviderAdmin(providerId))
+ {
+ return TypedResults.Unauthorized();
+ }
+
+ var subscriptionData = await providerBillingQueries.GetSubscriptionData(providerId);
+
+ if (subscriptionData == null)
+ {
+ return TypedResults.NotFound();
+ }
+
+ var (providerPlans, subscription) = subscriptionData;
+
+ var providerSubscriptionDTO = ProviderSubscriptionDTO.From(providerPlans, subscription);
+
+ return TypedResults.Ok(providerSubscriptionDTO);
+ }
+}
diff --git a/src/Api/Billing/Models/ProviderSubscriptionDTO.cs b/src/Api/Billing/Models/ProviderSubscriptionDTO.cs
new file mode 100644
index 0000000000..0e8b8bfb1c
--- /dev/null
+++ b/src/Api/Billing/Models/ProviderSubscriptionDTO.cs
@@ -0,0 +1,47 @@
+using Bit.Core.Billing.Models;
+using Bit.Core.Utilities;
+using Stripe;
+
+namespace Bit.Api.Billing.Models;
+
+public record ProviderSubscriptionDTO(
+ string Status,
+ DateTime CurrentPeriodEndDate,
+ decimal? DiscountPercentage,
+ IEnumerable Plans)
+{
+ private const string _annualCadence = "Annual";
+ private const string _monthlyCadence = "Monthly";
+
+ public static ProviderSubscriptionDTO From(
+ IEnumerable providerPlans,
+ Subscription subscription)
+ {
+ var providerPlansDTO = providerPlans
+ .Select(providerPlan =>
+ {
+ var plan = StaticStore.GetPlan(providerPlan.PlanType);
+ var cost = (providerPlan.SeatMinimum + providerPlan.PurchasedSeats) * plan.PasswordManager.SeatPrice;
+ var cadence = plan.IsAnnual ? _annualCadence : _monthlyCadence;
+ return new ProviderPlanDTO(
+ plan.Name,
+ providerPlan.SeatMinimum,
+ providerPlan.PurchasedSeats,
+ cost,
+ cadence);
+ });
+
+ return new ProviderSubscriptionDTO(
+ subscription.Status,
+ subscription.CurrentPeriodEnd,
+ subscription.Customer?.Discount?.Coupon?.PercentOff,
+ providerPlansDTO);
+ }
+}
+
+public record ProviderPlanDTO(
+ string PlanName,
+ int SeatMinimum,
+ int PurchasedSeats,
+ decimal Cost,
+ string Cadence);
diff --git a/src/Core/AdminConsole/Entities/Provider/Provider.cs b/src/Core/AdminConsole/Entities/Provider/Provider.cs
index ee2b35ed90..e5b794e6b1 100644
--- a/src/Core/AdminConsole/Entities/Provider/Provider.cs
+++ b/src/Core/AdminConsole/Entities/Provider/Provider.cs
@@ -6,7 +6,7 @@ using Bit.Core.Utilities;
namespace Bit.Core.AdminConsole.Entities.Provider;
-public class Provider : ITableObject
+public class Provider : ITableObject, ISubscriber
{
public Guid Id { get; set; }
///
@@ -34,6 +34,26 @@ public class Provider : ITableObject
public string GatewayCustomerId { get; set; }
public string GatewaySubscriptionId { get; set; }
+ public string BillingEmailAddress() => BillingEmail?.ToLowerInvariant().Trim();
+
+ public string BillingName() => DisplayBusinessName();
+
+ public string SubscriberName() => DisplayName();
+
+ public string BraintreeCustomerIdPrefix() => "p";
+
+ public string BraintreeIdField() => "provider_id";
+
+ public string BraintreeCloudRegionField() => "region";
+
+ public bool IsOrganization() => false;
+
+ public bool IsUser() => false;
+
+ public string SubscriberType() => "Provider";
+
+ public bool IsExpired() => false;
+
public void SetNewId()
{
if (Id == default)
diff --git a/src/Core/Billing/BillingException.cs b/src/Core/Billing/BillingException.cs
new file mode 100644
index 0000000000..a6944b3ed6
--- /dev/null
+++ b/src/Core/Billing/BillingException.cs
@@ -0,0 +1,9 @@
+namespace Bit.Core.Billing;
+
+public class BillingException(
+ string clientFriendlyMessage,
+ string internalMessage = null,
+ Exception innerException = null) : Exception(internalMessage, innerException)
+{
+ public string ClientFriendlyMessage { get; set; } = clientFriendlyMessage;
+}
diff --git a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs b/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs
index b23880e650..88708d3d2e 100644
--- a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs
+++ b/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs
@@ -1,7 +1,6 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Models;
using Bit.Core.Entities;
-using Bit.Core.Exceptions;
using Stripe;
namespace Bit.Core.Billing.Commands;
@@ -17,7 +16,6 @@ public interface ICancelSubscriptionCommand
/// The or with the subscription to cancel.
/// An DTO containing user-provided feedback on why they are cancelling the subscription.
/// A flag indicating whether to cancel the subscription immediately or at the end of the subscription period.
- /// Thrown when the provided subscription is already in an inactive state.
Task CancelSubscription(
Subscription subscription,
OffboardingSurveyResponse offboardingSurveyResponse,
diff --git a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs
index 62bf0d0926..e2be6f45eb 100644
--- a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs
+++ b/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs
@@ -4,5 +4,12 @@ namespace Bit.Core.Billing.Commands;
public interface IRemovePaymentMethodCommand
{
+ ///
+ /// Attempts to remove an Organization's saved payment method. If the Stripe representing the
+ /// contains a valid "btCustomerId" key in its property,
+ /// this command will attempt to remove the Braintree . Otherwise, it will attempt to remove the
+ /// Stripe .
+ ///
+ /// The organization to remove the saved payment method for.
Task RemovePaymentMethod(Organization organization);
}
diff --git a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs
index c5dbb6d927..be8479ea99 100644
--- a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs
+++ b/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs
@@ -1,55 +1,41 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Enums;
-using Bit.Core.Exceptions;
using Bit.Core.Services;
using Braintree;
using Microsoft.Extensions.Logging;
+using static Bit.Core.Billing.Utilities;
+
namespace Bit.Core.Billing.Commands.Implementations;
-public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand
+public class RemovePaymentMethodCommand(
+ IBraintreeGateway braintreeGateway,
+ ILogger logger,
+ IStripeAdapter stripeAdapter)
+ : IRemovePaymentMethodCommand
{
- private readonly IBraintreeGateway _braintreeGateway;
- private readonly ILogger _logger;
- private readonly IStripeAdapter _stripeAdapter;
-
- public RemovePaymentMethodCommand(
- IBraintreeGateway braintreeGateway,
- ILogger logger,
- IStripeAdapter stripeAdapter)
- {
- _braintreeGateway = braintreeGateway;
- _logger = logger;
- _stripeAdapter = stripeAdapter;
- }
-
public async Task RemovePaymentMethod(Organization organization)
{
- const string braintreeCustomerIdKey = "btCustomerId";
-
- if (organization == null)
- {
- throw new ArgumentNullException(nameof(organization));
- }
+ ArgumentNullException.ThrowIfNull(organization);
if (organization.Gateway is not GatewayType.Stripe || string.IsNullOrEmpty(organization.GatewayCustomerId))
{
throw ContactSupport();
}
- var stripeCustomer = await _stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, new Stripe.CustomerGetOptions
+ var stripeCustomer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, new Stripe.CustomerGetOptions
{
- Expand = new List { "invoice_settings.default_payment_method", "sources" }
+ Expand = ["invoice_settings.default_payment_method", "sources"]
});
if (stripeCustomer == null)
{
- _logger.LogError("Could not find Stripe customer ({ID}) when removing payment method", organization.GatewayCustomerId);
+ logger.LogError("Could not find Stripe customer ({ID}) when removing payment method", organization.GatewayCustomerId);
throw ContactSupport();
}
- if (stripeCustomer.Metadata?.TryGetValue(braintreeCustomerIdKey, out var braintreeCustomerId) ?? false)
+ if (stripeCustomer.Metadata?.TryGetValue(BraintreeCustomerIdKey, out var braintreeCustomerId) ?? false)
{
await RemoveBraintreePaymentMethodAsync(braintreeCustomerId);
}
@@ -61,11 +47,11 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand
private async Task RemoveBraintreePaymentMethodAsync(string braintreeCustomerId)
{
- var customer = await _braintreeGateway.Customer.FindAsync(braintreeCustomerId);
+ var customer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId);
if (customer == null)
{
- _logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId);
+ logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId);
throw ContactSupport();
}
@@ -74,27 +60,27 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand
{
var existingDefaultPaymentMethod = customer.DefaultPaymentMethod;
- var updateCustomerResult = await _braintreeGateway.Customer.UpdateAsync(
+ var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync(
braintreeCustomerId,
new CustomerRequest { DefaultPaymentMethodToken = null });
if (!updateCustomerResult.IsSuccess())
{
- _logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}",
+ logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}",
braintreeCustomerId, updateCustomerResult.Message);
throw ContactSupport();
}
- var deletePaymentMethodResult = await _braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token);
+ var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token);
if (!deletePaymentMethodResult.IsSuccess())
{
- await _braintreeGateway.Customer.UpdateAsync(
+ await braintreeGateway.Customer.UpdateAsync(
braintreeCustomerId,
new CustomerRequest { DefaultPaymentMethodToken = existingDefaultPaymentMethod.Token });
- _logger.LogError(
+ logger.LogError(
"Failed to delete Braintree payment method for Customer ({ID}), re-linked payment method. Message: {Message}",
braintreeCustomerId, deletePaymentMethodResult.Message);
@@ -103,7 +89,7 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand
}
else
{
- _logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId);
+ logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId);
}
}
@@ -116,25 +102,23 @@ public class RemovePaymentMethodCommand : IRemovePaymentMethodCommand
switch (source)
{
case Stripe.BankAccount:
- await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id);
+ await stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id);
break;
case Stripe.Card:
- await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id);
+ await stripeAdapter.CardDeleteAsync(customer.Id, source.Id);
break;
}
}
}
- var paymentMethods = _stripeAdapter.PaymentMethodListAutoPagingAsync(new Stripe.PaymentMethodListOptions
+ var paymentMethods = stripeAdapter.PaymentMethodListAutoPagingAsync(new Stripe.PaymentMethodListOptions
{
Customer = customer.Id
});
await foreach (var paymentMethod in paymentMethods)
{
- await _stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id, new Stripe.PaymentMethodDetachOptions());
+ await stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id, new Stripe.PaymentMethodDetachOptions());
}
}
-
- private static GatewayException ContactSupport() => new("Could not remove your payment method. Please contact support for assistance.");
}
diff --git a/src/Core/Billing/Entities/ProviderPlan.cs b/src/Core/Billing/Entities/ProviderPlan.cs
index 325dbbb156..2f15a539e1 100644
--- a/src/Core/Billing/Entities/ProviderPlan.cs
+++ b/src/Core/Billing/Entities/ProviderPlan.cs
@@ -11,7 +11,6 @@ public class ProviderPlan : ITableObject
public PlanType PlanType { get; set; }
public int? SeatMinimum { get; set; }
public int? PurchasedSeats { get; set; }
- public int? AllocatedSeats { get; set; }
public void SetNewId()
{
@@ -20,4 +19,6 @@ public class ProviderPlan : ITableObject
Id = CoreHelpers.GenerateComb();
}
}
+
+ public bool Configured => SeatMinimum.HasValue && PurchasedSeats.HasValue;
}
diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs
index 113fa4d5b7..751bfdb671 100644
--- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs
+++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs
@@ -17,6 +17,7 @@ public static class ServiceCollectionExtensions
public static void AddBillingQueries(this IServiceCollection services)
{
- services.AddSingleton();
+ services.AddSingleton();
+ services.AddSingleton();
}
}
diff --git a/src/Core/Billing/Models/ConfiguredProviderPlan.cs b/src/Core/Billing/Models/ConfiguredProviderPlan.cs
new file mode 100644
index 0000000000..d5d53b36fa
--- /dev/null
+++ b/src/Core/Billing/Models/ConfiguredProviderPlan.cs
@@ -0,0 +1,22 @@
+using Bit.Core.Billing.Entities;
+using Bit.Core.Enums;
+
+namespace Bit.Core.Billing.Models;
+
+public record ConfiguredProviderPlan(
+ Guid Id,
+ Guid ProviderId,
+ PlanType PlanType,
+ int SeatMinimum,
+ int PurchasedSeats)
+{
+ public static ConfiguredProviderPlan From(ProviderPlan providerPlan) =>
+ providerPlan.Configured
+ ? new ConfiguredProviderPlan(
+ providerPlan.Id,
+ providerPlan.ProviderId,
+ providerPlan.PlanType,
+ providerPlan.SeatMinimum.GetValueOrDefault(0),
+ providerPlan.PurchasedSeats.GetValueOrDefault(0))
+ : null;
+}
diff --git a/src/Core/Billing/Models/ProviderSubscriptionData.cs b/src/Core/Billing/Models/ProviderSubscriptionData.cs
new file mode 100644
index 0000000000..27da6cd226
--- /dev/null
+++ b/src/Core/Billing/Models/ProviderSubscriptionData.cs
@@ -0,0 +1,7 @@
+using Stripe;
+
+namespace Bit.Core.Billing.Models;
+
+public record ProviderSubscriptionData(
+ List ProviderPlans,
+ Subscription Subscription);
diff --git a/src/Core/Billing/Queries/IGetSubscriptionQuery.cs b/src/Core/Billing/Queries/IGetSubscriptionQuery.cs
deleted file mode 100644
index 9ba2a85ed5..0000000000
--- a/src/Core/Billing/Queries/IGetSubscriptionQuery.cs
+++ /dev/null
@@ -1,18 +0,0 @@
-using Bit.Core.Entities;
-using Bit.Core.Exceptions;
-using Stripe;
-
-namespace Bit.Core.Billing.Queries;
-
-public interface IGetSubscriptionQuery
-{
- ///
- /// Retrieves a Stripe using the 's property.
- ///
- /// The organization or user to retrieve the subscription for.
- /// A Stripe .
- /// Thrown when the is .
- /// Thrown when the subscriber's is or empty.
- /// Thrown when the returned from Stripe's API is null.
- Task GetSubscription(ISubscriber subscriber);
-}
diff --git a/src/Core/Billing/Queries/IProviderBillingQueries.cs b/src/Core/Billing/Queries/IProviderBillingQueries.cs
new file mode 100644
index 0000000000..1edfddaf56
--- /dev/null
+++ b/src/Core/Billing/Queries/IProviderBillingQueries.cs
@@ -0,0 +1,14 @@
+using Bit.Core.Billing.Models;
+
+namespace Bit.Core.Billing.Queries;
+
+public interface IProviderBillingQueries
+{
+ ///
+ /// Retrieves a provider's billing subscription data.
+ ///
+ /// The ID of the provider to retrieve subscription data for.
+ /// A object containing the provider's Stripe and their s.
+ /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints.
+ Task GetSubscriptionData(Guid providerId);
+}
diff --git a/src/Core/Billing/Queries/ISubscriberQueries.cs b/src/Core/Billing/Queries/ISubscriberQueries.cs
new file mode 100644
index 0000000000..ea6c0d985e
--- /dev/null
+++ b/src/Core/Billing/Queries/ISubscriberQueries.cs
@@ -0,0 +1,30 @@
+using Bit.Core.Entities;
+using Bit.Core.Exceptions;
+using Stripe;
+
+namespace Bit.Core.Billing.Queries;
+
+public interface ISubscriberQueries
+{
+ ///
+ /// Retrieves a Stripe using the 's property.
+ ///
+ /// The organization, provider or user to retrieve the subscription for.
+ /// Optional parameters that can be passed to Stripe to expand or modify the .
+ /// A Stripe .
+ /// Thrown when the is .
+ /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints.
+ Task GetSubscription(
+ ISubscriber subscriber,
+ SubscriptionGetOptions subscriptionGetOptions = null);
+
+ ///
+ /// Retrieves a Stripe using the 's property.
+ ///
+ /// The organization or user to retrieve the subscription for.
+ /// A Stripe .
+ /// Thrown when the is .
+ /// Thrown when the subscriber's is or empty.
+ /// Thrown when the returned from Stripe's API is null.
+ Task GetSubscriptionOrThrow(ISubscriber subscriber);
+}
diff --git a/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs b/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs
deleted file mode 100644
index c3b0a29552..0000000000
--- a/src/Core/Billing/Queries/Implementations/GetSubscriptionQuery.cs
+++ /dev/null
@@ -1,36 +0,0 @@
-using Bit.Core.Entities;
-using Bit.Core.Services;
-using Microsoft.Extensions.Logging;
-using Stripe;
-
-using static Bit.Core.Billing.Utilities;
-
-namespace Bit.Core.Billing.Queries.Implementations;
-
-public class GetSubscriptionQuery(
- ILogger logger,
- IStripeAdapter stripeAdapter) : IGetSubscriptionQuery
-{
- public async Task GetSubscription(ISubscriber subscriber)
- {
- ArgumentNullException.ThrowIfNull(subscriber);
-
- if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
- {
- logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id);
-
- throw ContactSupport();
- }
-
- var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId);
-
- if (subscription != null)
- {
- return subscription;
- }
-
- logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId);
-
- throw ContactSupport();
- }
-}
diff --git a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs
new file mode 100644
index 0000000000..c921e82969
--- /dev/null
+++ b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs
@@ -0,0 +1,49 @@
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.Billing.Models;
+using Bit.Core.Billing.Repositories;
+using Microsoft.Extensions.Logging;
+using Stripe;
+
+namespace Bit.Core.Billing.Queries.Implementations;
+
+public class ProviderBillingQueries(
+ ILogger logger,
+ IProviderPlanRepository providerPlanRepository,
+ IProviderRepository providerRepository,
+ ISubscriberQueries subscriberQueries) : IProviderBillingQueries
+{
+ public async Task GetSubscriptionData(Guid providerId)
+ {
+ var provider = await providerRepository.GetByIdAsync(providerId);
+
+ if (provider == null)
+ {
+ logger.LogError(
+ "Could not find provider ({ID}) when retrieving subscription data.",
+ providerId);
+
+ return null;
+ }
+
+ var subscription = await subscriberQueries.GetSubscription(provider, new SubscriptionGetOptions
+ {
+ Expand = ["customer"]
+ });
+
+ if (subscription == null)
+ {
+ return null;
+ }
+
+ var providerPlans = await providerPlanRepository.GetByProviderId(providerId);
+
+ var configuredProviderPlans = providerPlans
+ .Where(providerPlan => providerPlan.Configured)
+ .Select(ConfiguredProviderPlan.From)
+ .ToList();
+
+ return new ProviderSubscriptionData(
+ configuredProviderPlans,
+ subscription);
+ }
+}
diff --git a/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs
new file mode 100644
index 0000000000..a160a87595
--- /dev/null
+++ b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs
@@ -0,0 +1,61 @@
+using Bit.Core.Entities;
+using Bit.Core.Services;
+using Microsoft.Extensions.Logging;
+using Stripe;
+
+using static Bit.Core.Billing.Utilities;
+
+namespace Bit.Core.Billing.Queries.Implementations;
+
+public class SubscriberQueries(
+ ILogger logger,
+ IStripeAdapter stripeAdapter) : ISubscriberQueries
+{
+ public async Task GetSubscription(
+ ISubscriber subscriber,
+ SubscriptionGetOptions subscriptionGetOptions = null)
+ {
+ ArgumentNullException.ThrowIfNull(subscriber);
+
+ if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
+ {
+ logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id);
+
+ return null;
+ }
+
+ var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions);
+
+ if (subscription != null)
+ {
+ return subscription;
+ }
+
+ logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId);
+
+ return null;
+ }
+
+ public async Task GetSubscriptionOrThrow(ISubscriber subscriber)
+ {
+ ArgumentNullException.ThrowIfNull(subscriber);
+
+ if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId))
+ {
+ logger.LogError("Cannot cancel subscription for subscriber ({ID}) with no GatewaySubscriptionId.", subscriber.Id);
+
+ throw ContactSupport();
+ }
+
+ var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId);
+
+ if (subscription != null)
+ {
+ return subscription;
+ }
+
+ logger.LogError("Could not find Stripe subscription ({ID}) to cancel.", subscriber.GatewaySubscriptionId);
+
+ throw ContactSupport();
+ }
+}
diff --git a/src/Core/Billing/Repositories/IProviderPlanRepository.cs b/src/Core/Billing/Repositories/IProviderPlanRepository.cs
index ccfc6ee683..eccbad82bb 100644
--- a/src/Core/Billing/Repositories/IProviderPlanRepository.cs
+++ b/src/Core/Billing/Repositories/IProviderPlanRepository.cs
@@ -5,5 +5,5 @@ namespace Bit.Core.Billing.Repositories;
public interface IProviderPlanRepository : IRepository
{
- Task GetByProviderId(Guid providerId);
+ Task> GetByProviderId(Guid providerId);
}
diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs
index 54ace07a70..2b06f1ea6c 100644
--- a/src/Core/Billing/Utilities.cs
+++ b/src/Core/Billing/Utilities.cs
@@ -1,8 +1,11 @@
-using Bit.Core.Exceptions;
-
-namespace Bit.Core.Billing;
+namespace Bit.Core.Billing;
public static class Utilities
{
- public static GatewayException ContactSupport() => new("Something went wrong with your request. Please contact support.");
+ public const string BraintreeCustomerIdKey = "btCustomerId";
+
+ public static BillingException ContactSupport(
+ string internalMessage = null,
+ Exception innerException = null) => new("Something went wrong with your request. Please contact support.",
+ internalMessage, innerException);
}
diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs
index 598a5c062b..2b8ff33211 100644
--- a/src/Core/Constants.cs
+++ b/src/Core/Constants.cs
@@ -130,6 +130,7 @@ public static class FeatureFlagKeys
public const string PM5864DollarThreshold = "PM-5864-dollar-threshold";
public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email";
public const string ShowPaymentMethodWarningBanners = "show-payment-method-warning-banners";
+ public const string EnableConsolidatedBilling = "enable-consolidated-billing";
public static List GetAllKeys()
{
diff --git a/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs b/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs
index 761545a255..f8448f4198 100644
--- a/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs
+++ b/src/Infrastructure.Dapper/Billing/Repositories/ProviderPlanRepository.cs
@@ -14,7 +14,7 @@ public class ProviderPlanRepository(
globalSettings.SqlServer.ConnectionString,
globalSettings.SqlServer.ReadOnlyConnectionString), IProviderPlanRepository
{
- public async Task GetByProviderId(Guid providerId)
+ public async Task> GetByProviderId(Guid providerId)
{
var sqlConnection = new SqlConnection(ConnectionString);
@@ -23,6 +23,6 @@ public class ProviderPlanRepository(
new { ProviderId = providerId },
commandType: CommandType.StoredProcedure);
- return results.FirstOrDefault();
+ return results.ToArray();
}
}
diff --git a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs
index 2f9a707b27..386f7115d7 100644
--- a/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs
+++ b/src/Infrastructure.EntityFramework/Billing/Repositories/ProviderPlanRepository.cs
@@ -16,14 +16,17 @@ public class ProviderPlanRepository(
mapper,
context => context.ProviderPlans), IProviderPlanRepository
{
- public async Task GetByProviderId(Guid providerId)
+ public async Task> GetByProviderId(Guid providerId)
{
using var serviceScope = ServiceScopeFactory.CreateScope();
+
var databaseContext = GetDatabaseContext(serviceScope);
+
var query =
from providerPlan in databaseContext.ProviderPlans
where providerPlan.ProviderId == providerId
select providerPlan;
- return await query.FirstOrDefaultAsync();
+
+ return await query.ToArrayAsync();
}
}
diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs
index fdbcc17e46..9d3c7ebfe5 100644
--- a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs
+++ b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs
@@ -56,7 +56,7 @@ public class OrganizationsControllerTests : IDisposable
private readonly IAddSecretsManagerSubscriptionCommand _addSecretsManagerSubscriptionCommand;
private readonly IPushNotificationService _pushNotificationService;
private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand;
- private readonly IGetSubscriptionQuery _getSubscriptionQuery;
+ private readonly ISubscriberQueries _subscriberQueries;
private readonly IReferenceEventService _referenceEventService;
private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand;
@@ -86,7 +86,7 @@ public class OrganizationsControllerTests : IDisposable
_addSecretsManagerSubscriptionCommand = Substitute.For();
_pushNotificationService = Substitute.For();
_cancelSubscriptionCommand = Substitute.For();
- _getSubscriptionQuery = Substitute.For();
+ _subscriberQueries = Substitute.For();
_referenceEventService = Substitute.For();
_organizationEnableCollectionEnhancementsCommand = Substitute.For();
@@ -113,7 +113,7 @@ public class OrganizationsControllerTests : IDisposable
_addSecretsManagerSubscriptionCommand,
_pushNotificationService,
_cancelSubscriptionCommand,
- _getSubscriptionQuery,
+ _subscriberQueries,
_referenceEventService,
_organizationEnableCollectionEnhancementsCommand);
}
diff --git a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs
index 79aa2ca13d..4af60689c3 100644
--- a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs
+++ b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs
@@ -57,7 +57,7 @@ public class AccountsControllerTests : IDisposable
private readonly IRotateUserKeyCommand _rotateUserKeyCommand;
private readonly IFeatureService _featureService;
private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand;
- private readonly IGetSubscriptionQuery _getSubscriptionQuery;
+ private readonly ISubscriberQueries _subscriberQueries;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
@@ -90,7 +90,7 @@ public class AccountsControllerTests : IDisposable
_rotateUserKeyCommand = Substitute.For();
_featureService = Substitute.For();
_cancelSubscriptionCommand = Substitute.For();
- _getSubscriptionQuery = Substitute.For();
+ _subscriberQueries = Substitute.For();
_referenceEventService = Substitute.For();
_currentContext = Substitute.For();
_cipherValidator =
@@ -122,7 +122,7 @@ public class AccountsControllerTests : IDisposable
_rotateUserKeyCommand,
_featureService,
_cancelSubscriptionCommand,
- _getSubscriptionQuery,
+ _subscriberQueries,
_referenceEventService,
_currentContext,
_cipherValidator,
diff --git a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs
index 5de14f006f..968bfeb84d 100644
--- a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs
+++ b/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs
@@ -1,13 +1,13 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Commands.Implementations;
using Bit.Core.Enums;
-using Bit.Core.Exceptions;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using NSubstitute.ReturnsExtensions;
using Xunit;
+using static Bit.Core.Test.Billing.Utilities;
using BT = Braintree;
using S = Stripe;
@@ -355,13 +355,4 @@ public class RemovePaymentMethodCommandTests
return (braintreeGateway, customerGateway, paymentMethodGateway);
}
-
- private static async Task ThrowsContactSupportAsync(Func function)
- {
- const string message = "Could not remove your payment method. Please contact support for assistance.";
-
- var exception = await Assert.ThrowsAsync(function);
-
- Assert.Equal(message, exception.Message);
- }
}
diff --git a/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs b/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs
deleted file mode 100644
index adae46a791..0000000000
--- a/test/Core.Test/Billing/Queries/GetSubscriptionQueryTests.cs
+++ /dev/null
@@ -1,104 +0,0 @@
-using Bit.Core.AdminConsole.Entities;
-using Bit.Core.Billing.Queries.Implementations;
-using Bit.Core.Entities;
-using Bit.Core.Exceptions;
-using Bit.Core.Services;
-using Bit.Test.Common.AutoFixture;
-using Bit.Test.Common.AutoFixture.Attributes;
-using NSubstitute;
-using NSubstitute.ReturnsExtensions;
-using Stripe;
-using Xunit;
-
-namespace Bit.Core.Test.Billing.Queries;
-
-[SutProviderCustomize]
-public class GetSubscriptionQueryTests
-{
- [Theory, BitAutoData]
- public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException(
- SutProvider sutProvider)
- => await Assert.ThrowsAsync(
- async () => await sutProvider.Sut.GetSubscription(null));
-
- [Theory, BitAutoData]
- public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ThrowsGatewayException(
- Organization organization,
- SutProvider sutProvider)
- {
- organization.GatewaySubscriptionId = null;
-
- await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(organization));
- }
-
- [Theory, BitAutoData]
- public async Task GetSubscription_Organization_NoSubscription_ThrowsGatewayException(
- Organization organization,
- SutProvider sutProvider)
- {
- sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId)
- .ReturnsNull();
-
- await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(organization));
- }
-
- [Theory, BitAutoData]
- public async Task GetSubscription_Organization_Succeeds(
- Organization organization,
- SutProvider sutProvider)
- {
- var subscription = new Subscription();
-
- sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId)
- .Returns(subscription);
-
- var gotSubscription = await sutProvider.Sut.GetSubscription(organization);
-
- Assert.Equivalent(subscription, gotSubscription);
- }
-
- [Theory, BitAutoData]
- public async Task GetSubscription_User_NoGatewaySubscriptionId_ThrowsGatewayException(
- User user,
- SutProvider sutProvider)
- {
- user.GatewaySubscriptionId = null;
-
- await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(user));
- }
-
- [Theory, BitAutoData]
- public async Task GetSubscription_User_NoSubscription_ThrowsGatewayException(
- User user,
- SutProvider sutProvider)
- {
- sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId)
- .ReturnsNull();
-
- await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscription(user));
- }
-
- [Theory, BitAutoData]
- public async Task GetSubscription_User_Succeeds(
- User user,
- SutProvider sutProvider)
- {
- var subscription = new Subscription();
-
- sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId)
- .Returns(subscription);
-
- var gotSubscription = await sutProvider.Sut.GetSubscription(user);
-
- Assert.Equivalent(subscription, gotSubscription);
- }
-
- private static async Task ThrowsContactSupportAsync(Func function)
- {
- const string message = "Something went wrong with your request. Please contact support.";
-
- var exception = await Assert.ThrowsAsync(function);
-
- Assert.Equal(message, exception.Message);
- }
-}
diff --git a/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs
new file mode 100644
index 0000000000..0962ed32b1
--- /dev/null
+++ b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs
@@ -0,0 +1,151 @@
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.Billing.Entities;
+using Bit.Core.Billing.Models;
+using Bit.Core.Billing.Queries;
+using Bit.Core.Billing.Queries.Implementations;
+using Bit.Core.Billing.Repositories;
+using Bit.Core.Enums;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using NSubstitute;
+using NSubstitute.ReturnsExtensions;
+using Stripe;
+using Xunit;
+
+namespace Bit.Core.Test.Billing.Queries;
+
+[SutProviderCustomize]
+public class ProviderBillingQueriesTests
+{
+ #region GetSubscriptionData
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionData_NullProvider_ReturnsNull(
+ SutProvider sutProvider,
+ Guid providerId)
+ {
+ var providerRepository = sutProvider.GetDependency();
+
+ providerRepository.GetByIdAsync(providerId).ReturnsNull();
+
+ var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId);
+
+ Assert.Null(subscriptionData);
+
+ await providerRepository.Received(1).GetByIdAsync(providerId);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionData_NullSubscription_ReturnsNull(
+ SutProvider sutProvider,
+ Guid providerId,
+ Provider provider)
+ {
+ var providerRepository = sutProvider.GetDependency();
+
+ providerRepository.GetByIdAsync(providerId).Returns(provider);
+
+ var subscriberQueries = sutProvider.GetDependency();
+
+ subscriberQueries.GetSubscription(provider).ReturnsNull();
+
+ var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId);
+
+ Assert.Null(subscriptionData);
+
+ await providerRepository.Received(1).GetByIdAsync(providerId);
+
+ await subscriberQueries.Received(1).GetSubscription(
+ provider,
+ Arg.Is(
+ options => options.Expand.Count == 1 && options.Expand.First() == "customer"));
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionData_Success(
+ SutProvider sutProvider,
+ Guid providerId,
+ Provider provider)
+ {
+ var providerRepository = sutProvider.GetDependency();
+
+ providerRepository.GetByIdAsync(providerId).Returns(provider);
+
+ var subscriberQueries = sutProvider.GetDependency();
+
+ var subscription = new Subscription();
+
+ subscriberQueries.GetSubscription(provider, Arg.Is(
+ options => options.Expand.Count == 1 && options.Expand.First() == "customer")).Returns(subscription);
+
+ var providerPlanRepository = sutProvider.GetDependency();
+
+ var enterprisePlan = new ProviderPlan
+ {
+ Id = Guid.NewGuid(),
+ ProviderId = providerId,
+ PlanType = PlanType.EnterpriseMonthly,
+ SeatMinimum = 100,
+ PurchasedSeats = 0
+ };
+
+ var teamsPlan = new ProviderPlan
+ {
+ Id = Guid.NewGuid(),
+ ProviderId = providerId,
+ PlanType = PlanType.TeamsMonthly,
+ SeatMinimum = 50,
+ PurchasedSeats = 10
+ };
+
+ var providerPlans = new List
+ {
+ enterprisePlan,
+ teamsPlan,
+ };
+
+ providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans);
+
+ var subscriptionData = await sutProvider.Sut.GetSubscriptionData(providerId);
+
+ Assert.NotNull(subscriptionData);
+
+ Assert.Equivalent(subscriptionData.Subscription, subscription);
+
+ Assert.Equal(2, subscriptionData.ProviderPlans.Count);
+
+ var configuredEnterprisePlan =
+ subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan =>
+ configuredPlan.PlanType == PlanType.EnterpriseMonthly);
+
+ var configuredTeamsPlan =
+ subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan =>
+ configuredPlan.PlanType == PlanType.TeamsMonthly);
+
+ Compare(enterprisePlan, configuredEnterprisePlan);
+
+ Compare(teamsPlan, configuredTeamsPlan);
+
+ await providerRepository.Received(1).GetByIdAsync(providerId);
+
+ await subscriberQueries.Received(1).GetSubscription(
+ provider,
+ Arg.Is(
+ options => options.Expand.Count == 1 && options.Expand.First() == "customer"));
+
+ await providerPlanRepository.Received(1).GetByProviderId(providerId);
+
+ return;
+
+ void Compare(ProviderPlan providerPlan, ConfiguredProviderPlan configuredProviderPlan)
+ {
+ Assert.NotNull(configuredProviderPlan);
+ Assert.Equal(providerPlan.Id, configuredProviderPlan.Id);
+ Assert.Equal(providerPlan.ProviderId, configuredProviderPlan.ProviderId);
+ Assert.Equal(providerPlan.SeatMinimum!.Value, configuredProviderPlan.SeatMinimum);
+ Assert.Equal(providerPlan.PurchasedSeats!.Value, configuredProviderPlan.PurchasedSeats);
+ }
+ }
+ #endregion
+}
diff --git a/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs
new file mode 100644
index 0000000000..51682a6661
--- /dev/null
+++ b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs
@@ -0,0 +1,263 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.Billing.Queries.Implementations;
+using Bit.Core.Entities;
+using Bit.Core.Services;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using NSubstitute;
+using NSubstitute.ReturnsExtensions;
+using Stripe;
+using Xunit;
+
+using static Bit.Core.Test.Billing.Utilities;
+
+namespace Bit.Core.Test.Billing.Queries;
+
+[SutProviderCustomize]
+public class SubscriberQueriesTests
+{
+ #region GetSubscription
+ [Theory, BitAutoData]
+ public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException(
+ SutProvider sutProvider)
+ => await Assert.ThrowsAsync(
+ async () => await sutProvider.Sut.GetSubscription(null));
+
+ [Theory, BitAutoData]
+ public async Task GetSubscription_Organization_NoGatewaySubscriptionId_ReturnsNull(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ organization.GatewaySubscriptionId = null;
+
+ var gotSubscription = await sutProvider.Sut.GetSubscription(organization);
+
+ Assert.Null(gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscription_Organization_NoSubscription_ReturnsNull(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId)
+ .ReturnsNull();
+
+ var gotSubscription = await sutProvider.Sut.GetSubscription(organization);
+
+ Assert.Null(gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscription_Organization_Succeeds(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ var subscription = new Subscription();
+
+ sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId)
+ .Returns(subscription);
+
+ var gotSubscription = await sutProvider.Sut.GetSubscription(organization);
+
+ Assert.Equivalent(subscription, gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscription_User_NoGatewaySubscriptionId_ReturnsNull(
+ User user,
+ SutProvider sutProvider)
+ {
+ user.GatewaySubscriptionId = null;
+
+ var gotSubscription = await sutProvider.Sut.GetSubscription(user);
+
+ Assert.Null(gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscription_User_NoSubscription_ReturnsNull(
+ User user,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId)
+ .ReturnsNull();
+
+ var gotSubscription = await sutProvider.Sut.GetSubscription(user);
+
+ Assert.Null(gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscription_User_Succeeds(
+ User user,
+ SutProvider sutProvider)
+ {
+ var subscription = new Subscription();
+
+ sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId)
+ .Returns(subscription);
+
+ var gotSubscription = await sutProvider.Sut.GetSubscription(user);
+
+ Assert.Equivalent(subscription, gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscription_Provider_NoGatewaySubscriptionId_ReturnsNull(
+ Provider provider,
+ SutProvider sutProvider)
+ {
+ provider.GatewaySubscriptionId = null;
+
+ var gotSubscription = await sutProvider.Sut.GetSubscription(provider);
+
+ Assert.Null(gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscription_Provider_NoSubscription_ReturnsNull(
+ Provider provider,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId)
+ .ReturnsNull();
+
+ var gotSubscription = await sutProvider.Sut.GetSubscription(provider);
+
+ Assert.Null(gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscription_Provider_Succeeds(
+ Provider provider,
+ SutProvider sutProvider)
+ {
+ var subscription = new Subscription();
+
+ sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId)
+ .Returns(subscription);
+
+ var gotSubscription = await sutProvider.Sut.GetSubscription(provider);
+
+ Assert.Equivalent(subscription, gotSubscription);
+ }
+ #endregion
+
+ #region GetSubscriptionOrThrow
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_NullSubscriber_ThrowsArgumentNullException(
+ SutProvider sutProvider)
+ => await Assert.ThrowsAsync(
+ async () => await sutProvider.Sut.GetSubscriptionOrThrow(null));
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_Organization_NoGatewaySubscriptionId_ThrowsGatewayException(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ organization.GatewaySubscriptionId = null;
+
+ await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization));
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_Organization_NoSubscription_ThrowsGatewayException(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId)
+ .ReturnsNull();
+
+ await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization));
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_Organization_Succeeds(
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ var subscription = new Subscription();
+
+ sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId)
+ .Returns(subscription);
+
+ var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization);
+
+ Assert.Equivalent(subscription, gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_User_NoGatewaySubscriptionId_ThrowsGatewayException(
+ User user,
+ SutProvider sutProvider)
+ {
+ user.GatewaySubscriptionId = null;
+
+ await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user));
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_User_NoSubscription_ThrowsGatewayException(
+ User user,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId)
+ .ReturnsNull();
+
+ await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(user));
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_User_Succeeds(
+ User user,
+ SutProvider sutProvider)
+ {
+ var subscription = new Subscription();
+
+ sutProvider.GetDependency().SubscriptionGetAsync(user.GatewaySubscriptionId)
+ .Returns(subscription);
+
+ var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(user);
+
+ Assert.Equivalent(subscription, gotSubscription);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_Provider_NoGatewaySubscriptionId_ThrowsGatewayException(
+ Provider provider,
+ SutProvider sutProvider)
+ {
+ provider.GatewaySubscriptionId = null;
+
+ await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider));
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_Provider_NoSubscription_ThrowsGatewayException(
+ Provider provider,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId)
+ .ReturnsNull();
+
+ await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(provider));
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionOrThrow_Provider_Succeeds(
+ Provider provider,
+ SutProvider sutProvider)
+ {
+ var subscription = new Subscription();
+
+ sutProvider.GetDependency().SubscriptionGetAsync(provider.GatewaySubscriptionId)
+ .Returns(subscription);
+
+ var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(provider);
+
+ Assert.Equivalent(subscription, gotSubscription);
+ }
+ #endregion
+}
diff --git a/test/Core.Test/Billing/Utilities.cs b/test/Core.Test/Billing/Utilities.cs
index 359c010a29..ea9e6c694c 100644
--- a/test/Core.Test/Billing/Utilities.cs
+++ b/test/Core.Test/Billing/Utilities.cs
@@ -1,4 +1,4 @@
-using Bit.Core.Exceptions;
+using Bit.Core.Billing;
using Xunit;
using static Bit.Core.Billing.Utilities;
@@ -11,7 +11,7 @@ public static class Utilities
{
var contactSupport = ContactSupport();
- var exception = await Assert.ThrowsAsync(function);
+ var exception = await Assert.ThrowsAsync(function);
Assert.Equal(contactSupport.Message, exception.Message);
}
From c53e5eeab3bd91c797b0088f747b0a199a946c42 Mon Sep 17 00:00:00 2001
From: Matt Bishop
Date: Thu, 28 Mar 2024 16:36:24 -0400
Subject: [PATCH 09/21] [PM-6762] Move to Azure.Data.Tables (#3888)
* Move to Azure.Data.Tables
* Reorder usings
* Add new package to Renovate
* Add manual serialization and deserialization due to enums
* Properly retrieve just the next page
---
.github/renovate.json | 2 +-
src/Core/Core.csproj | 2 +-
src/Core/Models/Data/DictionaryEntity.cs | 134 ---------------
src/Core/Models/Data/EventTableEntity.cs | 159 +++++++++++-------
.../Models/Data/InstallationDeviceEntity.cs | 10 +-
.../TableStorage/EventRepository.cs | 89 +++-------
.../InstallationDeviceRepository.cs | 28 +--
7 files changed, 144 insertions(+), 280 deletions(-)
delete mode 100644 src/Core/Models/Data/DictionaryEntity.cs
diff --git a/.github/renovate.json b/.github/renovate.json
index 18d6e0bb61..91774ca33e 100644
--- a/.github/renovate.json
+++ b/.github/renovate.json
@@ -44,6 +44,7 @@
"matchPackageNames": [
"AspNetCoreRateLimit",
"AspNetCoreRateLimit.Redis",
+ "Azure.Data.Tables",
"Azure.Extensions.AspNetCore.DataProtection.Blobs",
"Azure.Messaging.EventGrid",
"Azure.Messaging.ServiceBus",
@@ -53,7 +54,6 @@
"Fido2.AspNet",
"Duende.IdentityServer",
"Microsoft.Azure.Cosmos",
- "Microsoft.Azure.Cosmos.Table",
"Microsoft.Extensions.Caching.StackExchangeRedis",
"Microsoft.Extensions.Identity.Stores",
"Otp.NET",
diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj
index 92c2198242..4189b9f525 100644
--- a/src/Core/Core.csproj
+++ b/src/Core/Core.csproj
@@ -23,6 +23,7 @@
+
@@ -35,7 +36,6 @@
-
diff --git a/src/Core/Models/Data/DictionaryEntity.cs b/src/Core/Models/Data/DictionaryEntity.cs
deleted file mode 100644
index 72e6c871c7..0000000000
--- a/src/Core/Models/Data/DictionaryEntity.cs
+++ /dev/null
@@ -1,134 +0,0 @@
-using System.Collections;
-using Microsoft.Azure.Cosmos.Table;
-
-namespace Bit.Core.Models.Data;
-
-public class DictionaryEntity : TableEntity, IDictionary
-{
- private IDictionary _properties = new Dictionary();
-
- public ICollection Values => _properties.Values;
-
- public EntityProperty this[string key]
- {
- get => _properties[key];
- set => _properties[key] = value;
- }
-
- public int Count => _properties.Count;
-
- public bool IsReadOnly => _properties.IsReadOnly;
-
- public ICollection Keys => _properties.Keys;
-
- public override void ReadEntity(IDictionary properties,
- OperationContext operationContext)
- {
- _properties = properties;
- }
-
- public override IDictionary WriteEntity(OperationContext operationContext)
- {
- return _properties;
- }
-
- public void Add(string key, EntityProperty value)
- {
- _properties.Add(key, value);
- }
-
- public void Add(string key, bool value)
- {
- _properties.Add(key, new EntityProperty(value));
- }
-
- public void Add(string key, byte[] value)
- {
- _properties.Add(key, new EntityProperty(value));
- }
-
- public void Add(string key, DateTime? value)
- {
- _properties.Add(key, new EntityProperty(value));
- }
-
- public void Add(string key, DateTimeOffset? value)
- {
- _properties.Add(key, new EntityProperty(value));
- }
-
- public void Add(string key, double value)
- {
- _properties.Add(key, new EntityProperty(value));
- }
-
- public void Add(string key, Guid value)
- {
- _properties.Add(key, new EntityProperty(value));
- }
-
- public void Add(string key, int value)
- {
- _properties.Add(key, new EntityProperty(value));
- }
-
- public void Add(string key, long value)
- {
- _properties.Add(key, new EntityProperty(value));
- }
-
- public void Add(string key, string value)
- {
- _properties.Add(key, new EntityProperty(value));
- }
-
- public void Add(KeyValuePair item)
- {
- _properties.Add(item);
- }
-
- public bool ContainsKey(string key)
- {
- return _properties.ContainsKey(key);
- }
-
- public bool Remove(string key)
- {
- return _properties.Remove(key);
- }
-
- public bool TryGetValue(string key, out EntityProperty value)
- {
- return _properties.TryGetValue(key, out value);
- }
-
- public void Clear()
- {
- _properties.Clear();
- }
-
- public bool Contains(KeyValuePair item)
- {
- return _properties.Contains(item);
- }
-
- public void CopyTo(KeyValuePair[] array, int arrayIndex)
- {
- _properties.CopyTo(array, arrayIndex);
- }
-
- public bool Remove(KeyValuePair item)
- {
- return _properties.Remove(item);
- }
-
- public IEnumerator> GetEnumerator()
- {
- return _properties.GetEnumerator();
- }
-
- IEnumerator IEnumerable.GetEnumerator()
- {
- return _properties.GetEnumerator();
- }
-}
diff --git a/src/Core/Models/Data/EventTableEntity.cs b/src/Core/Models/Data/EventTableEntity.cs
index df4a85acaf..69365f4127 100644
--- a/src/Core/Models/Data/EventTableEntity.cs
+++ b/src/Core/Models/Data/EventTableEntity.cs
@@ -1,10 +1,73 @@
-using Bit.Core.Enums;
+using Azure;
+using Azure.Data.Tables;
+using Bit.Core.Enums;
using Bit.Core.Utilities;
-using Microsoft.Azure.Cosmos.Table;
namespace Bit.Core.Models.Data;
-public class EventTableEntity : TableEntity, IEvent
+// used solely for interaction with Azure Table Storage
+public class AzureEvent : ITableEntity
+{
+ public string PartitionKey { get; set; }
+ public string RowKey { get; set; }
+ public DateTimeOffset? Timestamp { get; set; }
+ public ETag ETag { get; set; }
+
+ public DateTime Date { get; set; }
+ public int Type { get; set; }
+ public Guid? UserId { get; set; }
+ public Guid? OrganizationId { get; set; }
+ public Guid? InstallationId { get; set; }
+ public Guid? ProviderId { get; set; }
+ public Guid? CipherId { get; set; }
+ public Guid? CollectionId { get; set; }
+ public Guid? PolicyId { get; set; }
+ public Guid? GroupId { get; set; }
+ public Guid? OrganizationUserId { get; set; }
+ public Guid? ProviderUserId { get; set; }
+ public Guid? ProviderOrganizationId { get; set; }
+ public int? DeviceType { get; set; }
+ public string IpAddress { get; set; }
+ public Guid? ActingUserId { get; set; }
+ public int? SystemUser { get; set; }
+ public string DomainName { get; set; }
+ public Guid? SecretId { get; set; }
+ public Guid? ServiceAccountId { get; set; }
+
+ public EventTableEntity ToEventTableEntity()
+ {
+ return new EventTableEntity
+ {
+ PartitionKey = PartitionKey,
+ RowKey = RowKey,
+ Timestamp = Timestamp,
+ ETag = ETag,
+
+ Date = Date,
+ Type = (EventType)Type,
+ UserId = UserId,
+ OrganizationId = OrganizationId,
+ InstallationId = InstallationId,
+ ProviderId = ProviderId,
+ CipherId = CipherId,
+ CollectionId = CollectionId,
+ PolicyId = PolicyId,
+ GroupId = GroupId,
+ OrganizationUserId = OrganizationUserId,
+ ProviderUserId = ProviderUserId,
+ ProviderOrganizationId = ProviderOrganizationId,
+ DeviceType = DeviceType.HasValue ? (DeviceType)DeviceType.Value : null,
+ IpAddress = IpAddress,
+ ActingUserId = ActingUserId,
+ SystemUser = SystemUser.HasValue ? (EventSystemUser)SystemUser.Value : null,
+ DomainName = DomainName,
+ SecretId = SecretId,
+ ServiceAccountId = ServiceAccountId
+ };
+ }
+}
+
+public class EventTableEntity : IEvent
{
public EventTableEntity() { }
@@ -32,6 +95,11 @@ public class EventTableEntity : TableEntity, IEvent
ServiceAccountId = e.ServiceAccountId;
}
+ public string PartitionKey { get; set; }
+ public string RowKey { get; set; }
+ public DateTimeOffset? Timestamp { get; set; }
+ public ETag ETag { get; set; }
+
public DateTime Date { get; set; }
public EventType Type { get; set; }
public Guid? UserId { get; set; }
@@ -53,65 +121,36 @@ public class EventTableEntity : TableEntity, IEvent
public Guid? SecretId { get; set; }
public Guid? ServiceAccountId { get; set; }
- public override IDictionary WriteEntity(OperationContext operationContext)
+ public AzureEvent ToAzureEvent()
{
- var result = base.WriteEntity(operationContext);
+ return new AzureEvent
+ {
+ PartitionKey = PartitionKey,
+ RowKey = RowKey,
+ Timestamp = Timestamp,
+ ETag = ETag,
- var typeName = nameof(Type);
- if (result.ContainsKey(typeName))
- {
- result[typeName] = new EntityProperty((int)Type);
- }
- else
- {
- result.Add(typeName, new EntityProperty((int)Type));
- }
-
- var deviceTypeName = nameof(DeviceType);
- if (result.ContainsKey(deviceTypeName))
- {
- result[deviceTypeName] = new EntityProperty((int?)DeviceType);
- }
- else
- {
- result.Add(deviceTypeName, new EntityProperty((int?)DeviceType));
- }
-
- var systemUserTypeName = nameof(SystemUser);
- if (result.ContainsKey(systemUserTypeName))
- {
- result[systemUserTypeName] = new EntityProperty((int?)SystemUser);
- }
- else
- {
- result.Add(systemUserTypeName, new EntityProperty((int?)SystemUser));
- }
-
- return result;
- }
-
- public override void ReadEntity(IDictionary properties,
- OperationContext operationContext)
- {
- base.ReadEntity(properties, operationContext);
-
- var typeName = nameof(Type);
- if (properties.ContainsKey(typeName) && properties[typeName].Int32Value.HasValue)
- {
- Type = (EventType)properties[typeName].Int32Value.Value;
- }
-
- var deviceTypeName = nameof(DeviceType);
- if (properties.ContainsKey(deviceTypeName) && properties[deviceTypeName].Int32Value.HasValue)
- {
- DeviceType = (DeviceType)properties[deviceTypeName].Int32Value.Value;
- }
-
- var systemUserTypeName = nameof(SystemUser);
- if (properties.ContainsKey(systemUserTypeName) && properties[systemUserTypeName].Int32Value.HasValue)
- {
- SystemUser = (EventSystemUser)properties[systemUserTypeName].Int32Value.Value;
- }
+ Date = Date,
+ Type = (int)Type,
+ UserId = UserId,
+ OrganizationId = OrganizationId,
+ InstallationId = InstallationId,
+ ProviderId = ProviderId,
+ CipherId = CipherId,
+ CollectionId = CollectionId,
+ PolicyId = PolicyId,
+ GroupId = GroupId,
+ OrganizationUserId = OrganizationUserId,
+ ProviderUserId = ProviderUserId,
+ ProviderOrganizationId = ProviderOrganizationId,
+ DeviceType = DeviceType.HasValue ? (int)DeviceType.Value : null,
+ IpAddress = IpAddress,
+ ActingUserId = ActingUserId,
+ SystemUser = SystemUser.HasValue ? (int)SystemUser.Value : null,
+ DomainName = DomainName,
+ SecretId = SecretId,
+ ServiceAccountId = ServiceAccountId
+ };
}
public static List IndexEvent(EventMessage e)
diff --git a/src/Core/Models/Data/InstallationDeviceEntity.cs b/src/Core/Models/Data/InstallationDeviceEntity.cs
index cb7bf00873..3186efc661 100644
--- a/src/Core/Models/Data/InstallationDeviceEntity.cs
+++ b/src/Core/Models/Data/InstallationDeviceEntity.cs
@@ -1,8 +1,9 @@
-using Microsoft.Azure.Cosmos.Table;
+using Azure;
+using Azure.Data.Tables;
namespace Bit.Core.Models.Data;
-public class InstallationDeviceEntity : TableEntity
+public class InstallationDeviceEntity : ITableEntity
{
public InstallationDeviceEntity() { }
@@ -27,6 +28,11 @@ public class InstallationDeviceEntity : TableEntity
RowKey = parts[1];
}
+ public string PartitionKey { get; set; }
+ public string RowKey { get; set; }
+ public DateTimeOffset? Timestamp { get; set; }
+ public ETag ETag { get; set; }
+
public static bool IsInstallationDeviceId(string deviceId)
{
return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_';
diff --git a/src/Core/Repositories/TableStorage/EventRepository.cs b/src/Core/Repositories/TableStorage/EventRepository.cs
index 7044850033..7c5cb97dba 100644
--- a/src/Core/Repositories/TableStorage/EventRepository.cs
+++ b/src/Core/Repositories/TableStorage/EventRepository.cs
@@ -1,14 +1,14 @@
-using Bit.Core.Models.Data;
+using Azure.Data.Tables;
+using Bit.Core.Models.Data;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Bit.Core.Vault.Entities;
-using Microsoft.Azure.Cosmos.Table;
namespace Bit.Core.Repositories.TableStorage;
public class EventRepository : IEventRepository
{
- private readonly CloudTable _table;
+ private readonly TableClient _tableClient;
public EventRepository(GlobalSettings globalSettings)
: this(globalSettings.Events.ConnectionString)
@@ -16,9 +16,8 @@ public class EventRepository : IEventRepository
public EventRepository(string storageConnectionString)
{
- var storageAccount = CloudStorageAccount.Parse(storageConnectionString);
- var tableClient = storageAccount.CreateCloudTableClient();
- _table = tableClient.GetTableReference("event");
+ var tableClient = new TableServiceClient(storageConnectionString);
+ _tableClient = tableClient.GetTableClient("event");
}
public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate,
@@ -76,7 +75,7 @@ public class EventRepository : IEventRepository
throw new ArgumentException(nameof(e));
}
- await CreateEntityAsync(entity);
+ await CreateEventAsync(entity);
}
public async Task CreateManyAsync(IEnumerable e)
@@ -99,7 +98,7 @@ public class EventRepository : IEventRepository
var groupEntities = group.ToList();
if (groupEntities.Count == 1)
{
- await CreateEntityAsync(groupEntities.First());
+ await CreateEventAsync(groupEntities.First());
continue;
}
@@ -107,7 +106,7 @@ public class EventRepository : IEventRepository
var iterations = groupEntities.Count / 100;
for (var i = 0; i <= iterations; i++)
{
- var batch = new TableBatchOperation();
+ var batch = new List();
var batchEntities = groupEntities.Skip(i * 100).Take(100);
if (!batchEntities.Any())
{
@@ -116,19 +115,15 @@ public class EventRepository : IEventRepository
foreach (var entity in batchEntities)
{
- batch.InsertOrReplace(entity);
+ batch.Add(new TableTransactionAction(TableTransactionActionType.Add,
+ entity.ToAzureEvent()));
}
- await _table.ExecuteBatchAsync(batch);
+ await _tableClient.SubmitTransactionAsync(batch);
}
}
}
- public async Task CreateEntityAsync(ITableEntity entity)
- {
- await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity));
- }
-
public async Task> GetManyAsync(string partitionKey, string rowKey,
DateTime startDate, DateTime endDate, PageOptions pageOptions)
{
@@ -136,60 +131,28 @@ public class EventRepository : IEventRepository
var end = CoreHelpers.DateTimeToTableStorageKey(endDate);
var filter = MakeFilter(partitionKey, string.Format(rowKey, start), string.Format(rowKey, end));
- var query = new TableQuery().Where(filter).Take(pageOptions.PageSize);
var result = new PagedResult();
- var continuationToken = DeserializeContinuationToken(pageOptions?.ContinuationToken);
+ var query = _tableClient.QueryAsync(filter, pageOptions.PageSize);
- var queryResults = await _table.ExecuteQuerySegmentedAsync(query, continuationToken);
- result.ContinuationToken = SerializeContinuationToken(queryResults.ContinuationToken);
- result.Data.AddRange(queryResults.Results);
+ await using (var enumerator = query.AsPages(pageOptions?.ContinuationToken,
+ pageOptions.PageSize).GetAsyncEnumerator())
+ {
+ await enumerator.MoveNextAsync();
+
+ result.ContinuationToken = enumerator.Current.ContinuationToken;
+ result.Data.AddRange(enumerator.Current.Values.Select(e => e.ToEventTableEntity()));
+ }
return result;
}
+ private async Task CreateEventAsync(EventTableEntity entity)
+ {
+ await _tableClient.UpsertEntityAsync(entity.ToAzureEvent());
+ }
+
private string MakeFilter(string partitionKey, string rowStart, string rowEnd)
{
- var rowFilter = TableQuery.CombineFilters(
- TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.LessThanOrEqual, $"{rowStart}`"),
- TableOperators.And,
- TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.GreaterThanOrEqual, $"{rowEnd}_"));
-
- return TableQuery.CombineFilters(
- TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, partitionKey),
- TableOperators.And,
- rowFilter);
- }
-
- private string SerializeContinuationToken(TableContinuationToken token)
- {
- if (token == null)
- {
- return null;
- }
-
- return string.Format("{0}__{1}__{2}__{3}", (int)token.TargetLocation, token.NextTableName,
- token.NextPartitionKey, token.NextRowKey);
- }
-
- private TableContinuationToken DeserializeContinuationToken(string token)
- {
- if (string.IsNullOrWhiteSpace(token))
- {
- return null;
- }
-
- var tokenParts = token.Split(new string[] { "__" }, StringSplitOptions.None);
- if (tokenParts.Length < 4 || !Enum.TryParse(tokenParts[0], out StorageLocation tLoc))
- {
- return null;
- }
-
- return new TableContinuationToken
- {
- TargetLocation = tLoc,
- NextTableName = string.IsNullOrWhiteSpace(tokenParts[1]) ? null : tokenParts[1],
- NextPartitionKey = string.IsNullOrWhiteSpace(tokenParts[2]) ? null : tokenParts[2],
- NextRowKey = string.IsNullOrWhiteSpace(tokenParts[3]) ? null : tokenParts[3]
- };
+ return $"PartitionKey eq '{partitionKey}' and RowKey le '{rowStart}' and RowKey ge '{rowEnd}'";
}
}
diff --git a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs
index 32b466d1b3..2dee07dc2b 100644
--- a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs
+++ b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs
@@ -1,13 +1,12 @@
-using System.Net;
+using Azure.Data.Tables;
using Bit.Core.Models.Data;
using Bit.Core.Settings;
-using Microsoft.Azure.Cosmos.Table;
namespace Bit.Core.Repositories.TableStorage;
public class InstallationDeviceRepository : IInstallationDeviceRepository
{
- private readonly CloudTable _table;
+ private readonly TableClient _tableClient;
public InstallationDeviceRepository(GlobalSettings globalSettings)
: this(globalSettings.Events.ConnectionString)
@@ -15,14 +14,13 @@ public class InstallationDeviceRepository : IInstallationDeviceRepository
public InstallationDeviceRepository(string storageConnectionString)
{
- var storageAccount = CloudStorageAccount.Parse(storageConnectionString);
- var tableClient = storageAccount.CreateCloudTableClient();
- _table = tableClient.GetTableReference("installationdevice");
+ var tableClient = new TableServiceClient(storageConnectionString);
+ _tableClient = tableClient.GetTableClient("installationdevice");
}
public async Task UpsertAsync(InstallationDeviceEntity entity)
{
- await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity));
+ await _tableClient.UpsertEntityAsync(entity);
}
public async Task UpsertManyAsync(IList entities)
@@ -52,7 +50,7 @@ public class InstallationDeviceRepository : IInstallationDeviceRepository
var iterations = groupEntities.Count / 100;
for (var i = 0; i <= iterations; i++)
{
- var batch = new TableBatchOperation();
+ var batch = new List();
var batchEntities = groupEntities.Skip(i * 100).Take(100);
if (!batchEntities.Any())
{
@@ -61,24 +59,16 @@ public class InstallationDeviceRepository : IInstallationDeviceRepository
foreach (var entity in batchEntities)
{
- batch.InsertOrReplace(entity);
+ batch.Add(new TableTransactionAction(TableTransactionActionType.UpsertReplace, entity));
}
- await _table.ExecuteBatchAsync(batch);
+ await _tableClient.SubmitTransactionAsync(batch);
}
}
}
public async Task DeleteAsync(InstallationDeviceEntity entity)
{
- try
- {
- entity.ETag = "*";
- await _table.ExecuteAsync(TableOperation.Delete(entity));
- }
- catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound)
- {
- throw;
- }
+ await _tableClient.DeleteEntityAsync(entity.PartitionKey, entity.RowKey);
}
}
From e2cb406a95d4f7dab59a2dbd894b1773e1826fdd Mon Sep 17 00:00:00 2001
From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com>
Date: Fri, 29 Mar 2024 11:18:10 -0400
Subject: [PATCH 10/21] [AC-1910] Allocate seats to a provider organization
(#3936)
* Add endpoint to update a provider organization's seats for consolidated billing.
* Fixed failing tests
---
src/Admin/Startup.cs | 2 +-
.../ProviderOrganizationController.cs | 63 ++++
.../Billing/Models/ProviderSubscriptionDTO.cs | 2 +
.../UpdateProviderOrganizationRequestBody.cs | 6 +
src/Api/Startup.cs | 3 +-
...IAssignSeatsToClientOrganizationCommand.cs | 12 +
.../AssignSeatsToClientOrganizationCommand.cs | 174 +++++++++
src/Core/Billing/Entities/ProviderPlan.cs | 3 +-
.../Billing/Extensions/BillingExtensions.cs | 9 +
.../Extensions/ServiceCollectionExtensions.cs | 16 +-
.../Billing/Models/ConfiguredProviderPlan.cs | 8 +-
.../Queries/IProviderBillingQueries.cs | 15 +-
.../Implementations/ProviderBillingQueries.cs | 47 ++-
.../Business/CompleteSubscriptionUpdate.cs | 26 --
.../Business/ProviderSubscriptionUpdate.cs | 61 ++++
.../Models/Business/SeatSubscriptionUpdate.cs | 4 +-
.../ServiceAccountSubscriptionUpdate.cs | 4 +-
.../Business/SmSeatSubscriptionUpdate.cs | 4 +-
.../SponsorOrganizationSubscriptionUpdate.cs | 8 +-
.../Business/StorageSubscriptionUpdate.cs | 4 +-
.../Models/Business/SubscriptionUpdate.cs | 31 +-
src/Core/Services/IPaymentService.cs | 7 +
.../Implementations/StripePaymentService.cs | 22 +-
src/Core/Utilities/StaticStore.cs | 1 -
.../ProviderBillingControllerTests.cs | 130 +++++++
.../ProviderOrganizationControllerTests.cs | 168 +++++++++
...gnSeatsToClientOrganizationCommandTests.cs | 339 ++++++++++++++++++
.../Queries/ProviderBillingQueriesTests.cs | 7 +-
28 files changed, 1108 insertions(+), 68 deletions(-)
create mode 100644 src/Api/Billing/Controllers/ProviderOrganizationController.cs
create mode 100644 src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs
create mode 100644 src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs
create mode 100644 src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs
create mode 100644 src/Core/Billing/Extensions/BillingExtensions.cs
create mode 100644 src/Core/Models/Business/ProviderSubscriptionUpdate.cs
create mode 100644 test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs
create mode 100644 test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs
create mode 100644 test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs
diff --git a/src/Admin/Startup.cs b/src/Admin/Startup.cs
index db870266cc..788908d42a 100644
--- a/src/Admin/Startup.cs
+++ b/src/Admin/Startup.cs
@@ -88,7 +88,7 @@ public class Startup
services.AddBaseServices(globalSettings);
services.AddDefaultServices(globalSettings);
services.AddScoped();
- services.AddBillingCommands();
+ services.AddBillingOperations();
#if OSS
services.AddOosServices();
diff --git a/src/Api/Billing/Controllers/ProviderOrganizationController.cs b/src/Api/Billing/Controllers/ProviderOrganizationController.cs
new file mode 100644
index 0000000000..8760415f5e
--- /dev/null
+++ b/src/Api/Billing/Controllers/ProviderOrganizationController.cs
@@ -0,0 +1,63 @@
+using Bit.Api.Billing.Models;
+using Bit.Core;
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.Billing.Commands;
+using Bit.Core.Context;
+using Bit.Core.Repositories;
+using Bit.Core.Services;
+using Microsoft.AspNetCore.Mvc;
+
+namespace Bit.Api.Billing.Controllers;
+
+[Route("providers/{providerId:guid}/organizations")]
+public class ProviderOrganizationController(
+ IAssignSeatsToClientOrganizationCommand assignSeatsToClientOrganizationCommand,
+ ICurrentContext currentContext,
+ IFeatureService featureService,
+ ILogger logger,
+ IOrganizationRepository organizationRepository,
+ IProviderRepository providerRepository,
+ IProviderOrganizationRepository providerOrganizationRepository) : Controller
+{
+ [HttpPut("{providerOrganizationId:guid}")]
+ public async Task UpdateAsync(
+ [FromRoute] Guid providerId,
+ [FromRoute] Guid providerOrganizationId,
+ [FromBody] UpdateProviderOrganizationRequestBody requestBody)
+ {
+ if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling))
+ {
+ return TypedResults.NotFound();
+ }
+
+ if (!currentContext.ProviderProviderAdmin(providerId))
+ {
+ return TypedResults.Unauthorized();
+ }
+
+ var provider = await providerRepository.GetByIdAsync(providerId);
+
+ var providerOrganization = await providerOrganizationRepository.GetByIdAsync(providerOrganizationId);
+
+ if (provider == null || providerOrganization == null)
+ {
+ return TypedResults.NotFound();
+ }
+
+ var organization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId);
+
+ if (organization == null)
+ {
+ logger.LogError("The organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id);
+
+ return TypedResults.Problem();
+ }
+
+ await assignSeatsToClientOrganizationCommand.AssignSeatsToClientOrganization(
+ provider,
+ organization,
+ requestBody.AssignedSeats);
+
+ return TypedResults.NoContent();
+ }
+}
diff --git a/src/Api/Billing/Models/ProviderSubscriptionDTO.cs b/src/Api/Billing/Models/ProviderSubscriptionDTO.cs
index 0e8b8bfb1c..ad0714967d 100644
--- a/src/Api/Billing/Models/ProviderSubscriptionDTO.cs
+++ b/src/Api/Billing/Models/ProviderSubscriptionDTO.cs
@@ -27,6 +27,7 @@ public record ProviderSubscriptionDTO(
plan.Name,
providerPlan.SeatMinimum,
providerPlan.PurchasedSeats,
+ providerPlan.AssignedSeats,
cost,
cadence);
});
@@ -43,5 +44,6 @@ public record ProviderPlanDTO(
string PlanName,
int SeatMinimum,
int PurchasedSeats,
+ int AssignedSeats,
decimal Cost,
string Cadence);
diff --git a/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs b/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs
new file mode 100644
index 0000000000..7bac8fdef4
--- /dev/null
+++ b/src/Api/Billing/Models/UpdateProviderOrganizationRequestBody.cs
@@ -0,0 +1,6 @@
+namespace Bit.Api.Billing.Models;
+
+public class UpdateProviderOrganizationRequestBody
+{
+ public int AssignedSeats { get; set; }
+}
diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs
index 9f94325513..63b1a3c3cd 100644
--- a/src/Api/Startup.cs
+++ b/src/Api/Startup.cs
@@ -170,8 +170,7 @@ public class Startup
services.AddDefaultServices(globalSettings);
services.AddOrganizationSubscriptionServices();
services.AddCoreLocalizationServices();
- services.AddBillingCommands();
- services.AddBillingQueries();
+ services.AddBillingOperations();
// Authorization Handlers
services.AddAuthorizationHandlers();
diff --git a/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs b/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs
new file mode 100644
index 0000000000..db21926bec
--- /dev/null
+++ b/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs
@@ -0,0 +1,12 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
+
+namespace Bit.Core.Billing.Commands;
+
+public interface IAssignSeatsToClientOrganizationCommand
+{
+ Task AssignSeatsToClientOrganization(
+ Provider provider,
+ Organization organization,
+ int seats);
+}
diff --git a/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs b/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs
new file mode 100644
index 0000000000..be2c6be968
--- /dev/null
+++ b/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs
@@ -0,0 +1,174 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.AdminConsole.Enums.Provider;
+using Bit.Core.Billing.Entities;
+using Bit.Core.Billing.Extensions;
+using Bit.Core.Billing.Queries;
+using Bit.Core.Billing.Repositories;
+using Bit.Core.Repositories;
+using Bit.Core.Services;
+using Bit.Core.Utilities;
+using Microsoft.Extensions.Logging;
+using static Bit.Core.Billing.Utilities;
+
+namespace Bit.Core.Billing.Commands.Implementations;
+
+public class AssignSeatsToClientOrganizationCommand(
+ ILogger logger,
+ IOrganizationRepository organizationRepository,
+ IPaymentService paymentService,
+ IProviderBillingQueries providerBillingQueries,
+ IProviderPlanRepository providerPlanRepository) : IAssignSeatsToClientOrganizationCommand
+{
+ public async Task AssignSeatsToClientOrganization(
+ Provider provider,
+ Organization organization,
+ int seats)
+ {
+ ArgumentNullException.ThrowIfNull(provider);
+ ArgumentNullException.ThrowIfNull(organization);
+
+ if (provider.Type == ProviderType.Reseller)
+ {
+ logger.LogError("Reseller-type provider ({ID}) cannot assign seats to client organizations", provider.Id);
+
+ throw ContactSupport("Consolidated billing does not support reseller-type providers");
+ }
+
+ if (seats < 0)
+ {
+ throw new BillingException(
+ "You cannot assign negative seats to a client.",
+ "MSP cannot assign negative seats to a client organization");
+ }
+
+ if (seats == organization.Seats)
+ {
+ logger.LogWarning("Client organization ({ID}) already has {Seats} seats assigned", organization.Id, organization.Seats);
+
+ return;
+ }
+
+ var providerPlan = await GetProviderPlanAsync(provider, organization);
+
+ var providerSeatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0);
+
+ // How many seats the provider has assigned to all their client organizations that have the specified plan type.
+ var providerCurrentlyAssignedSeatTotal = await providerBillingQueries.GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType);
+
+ // How many seats are being added to or subtracted from this client organization.
+ var seatDifference = seats - (organization.Seats ?? 0);
+
+ // How many seats the provider will have assigned to all of their client organizations after the update.
+ var providerNewlyAssignedSeatTotal = providerCurrentlyAssignedSeatTotal + seatDifference;
+
+ var update = CurryUpdateFunction(
+ provider,
+ providerPlan,
+ organization,
+ seats,
+ providerNewlyAssignedSeatTotal);
+
+ /*
+ * Below the limit => Below the limit:
+ * No subscription update required. We can safely update the organization's seats.
+ */
+ if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum &&
+ providerNewlyAssignedSeatTotal <= providerSeatMinimum)
+ {
+ organization.Seats = seats;
+
+ await organizationRepository.ReplaceAsync(organization);
+
+ providerPlan.AllocatedSeats = providerNewlyAssignedSeatTotal;
+
+ await providerPlanRepository.ReplaceAsync(providerPlan);
+ }
+ /*
+ * Below the limit => Above the limit:
+ * We have to scale the subscription up from the seat minimum to the newly assigned seat total.
+ */
+ else if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum &&
+ providerNewlyAssignedSeatTotal > providerSeatMinimum)
+ {
+ await update(
+ providerSeatMinimum,
+ providerNewlyAssignedSeatTotal);
+ }
+ /*
+ * Above the limit => Above the limit:
+ * We have to scale the subscription from the currently assigned seat total to the newly assigned seat total.
+ */
+ else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum &&
+ providerNewlyAssignedSeatTotal > providerSeatMinimum)
+ {
+ await update(
+ providerCurrentlyAssignedSeatTotal,
+ providerNewlyAssignedSeatTotal);
+ }
+ /*
+ * Above the limit => Below the limit:
+ * We have to scale the subscription down from the currently assigned seat total to the seat minimum.
+ */
+ else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum &&
+ providerNewlyAssignedSeatTotal <= providerSeatMinimum)
+ {
+ await update(
+ providerCurrentlyAssignedSeatTotal,
+ providerSeatMinimum);
+ }
+ }
+
+ // ReSharper disable once SuggestBaseTypeForParameter
+ private async Task GetProviderPlanAsync(Provider provider, Organization organization)
+ {
+ if (!organization.PlanType.SupportsConsolidatedBilling())
+ {
+ logger.LogError("Cannot assign seats to a client organization ({ID}) with a plan type that does not support consolidated billing: {PlanType}", organization.Id, organization.PlanType);
+
+ throw ContactSupport();
+ }
+
+ var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
+
+ var providerPlan = providerPlans.FirstOrDefault(providerPlan => providerPlan.PlanType == organization.PlanType);
+
+ if (providerPlan != null && providerPlan.IsConfigured())
+ {
+ return providerPlan;
+ }
+
+ logger.LogError("Cannot assign seats to client organization ({ClientOrganizationID}) when provider's ({ProviderID}) matching plan is not configured", organization.Id, provider.Id);
+
+ throw ContactSupport();
+ }
+
+ private Func CurryUpdateFunction(
+ Provider provider,
+ ProviderPlan providerPlan,
+ Organization organization,
+ int organizationNewlyAssignedSeats,
+ int providerNewlyAssignedSeats) => async (providerCurrentlySubscribedSeats, providerNewlySubscribedSeats) =>
+ {
+ var plan = StaticStore.GetPlan(providerPlan.PlanType);
+
+ await paymentService.AdjustSeats(
+ provider,
+ plan,
+ providerCurrentlySubscribedSeats,
+ providerNewlySubscribedSeats);
+
+ organization.Seats = organizationNewlyAssignedSeats;
+
+ await organizationRepository.ReplaceAsync(organization);
+
+ var providerNewlyPurchasedSeats = providerNewlySubscribedSeats > providerPlan.SeatMinimum
+ ? providerNewlySubscribedSeats - providerPlan.SeatMinimum
+ : 0;
+
+ providerPlan.PurchasedSeats = providerNewlyPurchasedSeats;
+ providerPlan.AllocatedSeats = providerNewlyAssignedSeats;
+
+ await providerPlanRepository.ReplaceAsync(providerPlan);
+ };
+}
diff --git a/src/Core/Billing/Entities/ProviderPlan.cs b/src/Core/Billing/Entities/ProviderPlan.cs
index 2f15a539e1..f4965570d9 100644
--- a/src/Core/Billing/Entities/ProviderPlan.cs
+++ b/src/Core/Billing/Entities/ProviderPlan.cs
@@ -11,6 +11,7 @@ public class ProviderPlan : ITableObject
public PlanType PlanType { get; set; }
public int? SeatMinimum { get; set; }
public int? PurchasedSeats { get; set; }
+ public int? AllocatedSeats { get; set; }
public void SetNewId()
{
@@ -20,5 +21,5 @@ public class ProviderPlan : ITableObject
}
}
- public bool Configured => SeatMinimum.HasValue && PurchasedSeats.HasValue;
+ public bool IsConfigured() => SeatMinimum.HasValue && PurchasedSeats.HasValue && AllocatedSeats.HasValue;
}
diff --git a/src/Core/Billing/Extensions/BillingExtensions.cs b/src/Core/Billing/Extensions/BillingExtensions.cs
new file mode 100644
index 0000000000..c7abeb81e2
--- /dev/null
+++ b/src/Core/Billing/Extensions/BillingExtensions.cs
@@ -0,0 +1,9 @@
+using Bit.Core.Enums;
+
+namespace Bit.Core.Billing.Extensions;
+
+public static class BillingExtensions
+{
+ public static bool SupportsConsolidatedBilling(this PlanType planType)
+ => planType is PlanType.TeamsMonthly or PlanType.EnterpriseMonthly;
+}
diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs
index 751bfdb671..8e28b23397 100644
--- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs
+++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs
@@ -9,15 +9,15 @@ using Microsoft.Extensions.DependencyInjection;
public static class ServiceCollectionExtensions
{
- public static void AddBillingCommands(this IServiceCollection services)
+ public static void AddBillingOperations(this IServiceCollection services)
{
- services.AddSingleton();
- services.AddSingleton();
- }
+ // Queries
+ services.AddTransient();
+ services.AddTransient();
- public static void AddBillingQueries(this IServiceCollection services)
- {
- services.AddSingleton();
- services.AddSingleton();
+ // Commands
+ services.AddTransient();
+ services.AddTransient();
+ services.AddTransient();
}
}
diff --git a/src/Core/Billing/Models/ConfiguredProviderPlan.cs b/src/Core/Billing/Models/ConfiguredProviderPlan.cs
index d5d53b36fa..d6bc2b7522 100644
--- a/src/Core/Billing/Models/ConfiguredProviderPlan.cs
+++ b/src/Core/Billing/Models/ConfiguredProviderPlan.cs
@@ -8,15 +8,17 @@ public record ConfiguredProviderPlan(
Guid ProviderId,
PlanType PlanType,
int SeatMinimum,
- int PurchasedSeats)
+ int PurchasedSeats,
+ int AssignedSeats)
{
public static ConfiguredProviderPlan From(ProviderPlan providerPlan) =>
- providerPlan.Configured
+ providerPlan.IsConfigured()
? new ConfiguredProviderPlan(
providerPlan.Id,
providerPlan.ProviderId,
providerPlan.PlanType,
providerPlan.SeatMinimum.GetValueOrDefault(0),
- providerPlan.PurchasedSeats.GetValueOrDefault(0))
+ providerPlan.PurchasedSeats.GetValueOrDefault(0),
+ providerPlan.AllocatedSeats.GetValueOrDefault(0))
: null;
}
diff --git a/src/Core/Billing/Queries/IProviderBillingQueries.cs b/src/Core/Billing/Queries/IProviderBillingQueries.cs
index 1edfddaf56..e4b7d0f14d 100644
--- a/src/Core/Billing/Queries/IProviderBillingQueries.cs
+++ b/src/Core/Billing/Queries/IProviderBillingQueries.cs
@@ -1,9 +1,22 @@
-using Bit.Core.Billing.Models;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.AdminConsole.Enums.Provider;
+using Bit.Core.Billing.Models;
+using Bit.Core.Enums;
namespace Bit.Core.Billing.Queries;
public interface IProviderBillingQueries
{
+ ///
+ /// Retrieves the number of seats an MSP has assigned to its client organizations with a specified .
+ ///
+ /// The ID of the MSP to retrieve the assigned seat total for.
+ /// The type of plan to retrieve the assigned seat total for.
+ /// An representing the number of seats the provider has assigned to its client organizations with the specified .
+ /// Thrown when the provider represented by the is .
+ /// Thrown when the provider represented by the has .
+ Task GetAssignedSeatTotalForPlanOrThrow(Guid providerId, PlanType planType);
+
///
/// Retrieves a provider's billing subscription data.
///
diff --git a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs
index c921e82969..f8bff9d3fd 100644
--- a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs
+++ b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs
@@ -1,17 +1,53 @@
-using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.AdminConsole.Enums.Provider;
+using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Repositories;
+using Bit.Core.Enums;
+using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
using Stripe;
+using static Bit.Core.Billing.Utilities;
namespace Bit.Core.Billing.Queries.Implementations;
public class ProviderBillingQueries(
ILogger logger,
+ IProviderOrganizationRepository providerOrganizationRepository,
IProviderPlanRepository providerPlanRepository,
IProviderRepository providerRepository,
ISubscriberQueries subscriberQueries) : IProviderBillingQueries
{
+ public async Task GetAssignedSeatTotalForPlanOrThrow(
+ Guid providerId,
+ PlanType planType)
+ {
+ var provider = await providerRepository.GetByIdAsync(providerId);
+
+ if (provider == null)
+ {
+ logger.LogError(
+ "Could not find provider ({ID}) when retrieving assigned seat total",
+ providerId);
+
+ throw ContactSupport();
+ }
+
+ if (provider.Type == ProviderType.Reseller)
+ {
+ logger.LogError("Assigned seats cannot be retrieved for reseller-type provider ({ID})", providerId);
+
+ throw ContactSupport("Consolidated billing does not support reseller-type providers");
+ }
+
+ var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId);
+
+ var plan = StaticStore.GetPlan(planType);
+
+ return providerOrganizations
+ .Where(providerOrganization => providerOrganization.Plan == plan.Name)
+ .Sum(providerOrganization => providerOrganization.Seats ?? 0);
+ }
+
public async Task GetSubscriptionData(Guid providerId)
{
var provider = await providerRepository.GetByIdAsync(providerId);
@@ -25,6 +61,13 @@ public class ProviderBillingQueries(
return null;
}
+ if (provider.Type == ProviderType.Reseller)
+ {
+ logger.LogError("Subscription data cannot be retrieved for reseller-type provider ({ID})", providerId);
+
+ throw ContactSupport("Consolidated billing does not support reseller-type providers");
+ }
+
var subscription = await subscriberQueries.GetSubscription(provider, new SubscriptionGetOptions
{
Expand = ["customer"]
@@ -38,7 +81,7 @@ public class ProviderBillingQueries(
var providerPlans = await providerPlanRepository.GetByProviderId(providerId);
var configuredProviderPlans = providerPlans
- .Where(providerPlan => providerPlan.Configured)
+ .Where(providerPlan => providerPlan.IsConfigured())
.Select(ConfiguredProviderPlan.From)
.ToList();
diff --git a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs
index a1146cd2a0..aa1c92dc2e 100644
--- a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs
+++ b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs
@@ -1,5 +1,4 @@
using Bit.Core.AdminConsole.Entities;
-using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Stripe;
@@ -279,25 +278,6 @@ public class CompleteSubscriptionUpdate : SubscriptionUpdate
};
}
- private static SubscriptionItem FindSubscriptionItem(Subscription subscription, string planId)
- {
- if (string.IsNullOrEmpty(planId))
- {
- return null;
- }
-
- var data = subscription.Items.Data;
-
- var subscriptionItem = data.FirstOrDefault(item => item.Plan?.Id == planId) ?? data.FirstOrDefault(item => item.Price?.Id == planId);
-
- return subscriptionItem;
- }
-
- private static string GetPasswordManagerPlanId(StaticStore.Plan plan)
- => IsNonSeatBasedPlan(plan)
- ? plan.PasswordManager.StripePlanId
- : plan.PasswordManager.StripeSeatPlanId;
-
private static SubscriptionData GetSubscriptionDataFor(Organization organization)
{
var plan = Utilities.StaticStore.GetPlan(organization.PlanType);
@@ -320,10 +300,4 @@ public class CompleteSubscriptionUpdate : SubscriptionUpdate
0
};
}
-
- private static bool IsNonSeatBasedPlan(StaticStore.Plan plan)
- => plan.Type is
- >= PlanType.FamiliesAnnually2019 and <= PlanType.EnterpriseAnnually2019
- or PlanType.FamiliesAnnually
- or PlanType.TeamsStarter;
}
diff --git a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs b/src/Core/Models/Business/ProviderSubscriptionUpdate.cs
new file mode 100644
index 0000000000..8b29bebce5
--- /dev/null
+++ b/src/Core/Models/Business/ProviderSubscriptionUpdate.cs
@@ -0,0 +1,61 @@
+using Bit.Core.Billing.Extensions;
+using Bit.Core.Enums;
+using Stripe;
+
+using static Bit.Core.Billing.Utilities;
+
+namespace Bit.Core.Models.Business;
+
+public class ProviderSubscriptionUpdate : SubscriptionUpdate
+{
+ private readonly string _planId;
+ private readonly int _previouslyPurchasedSeats;
+ private readonly int _newlyPurchasedSeats;
+
+ protected override List PlanIds => [_planId];
+
+ public ProviderSubscriptionUpdate(
+ PlanType planType,
+ int previouslyPurchasedSeats,
+ int newlyPurchasedSeats)
+ {
+ if (!planType.SupportsConsolidatedBilling())
+ {
+ throw ContactSupport($"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing");
+ }
+
+ _planId = GetPasswordManagerPlanId(Utilities.StaticStore.GetPlan(planType));
+ _previouslyPurchasedSeats = previouslyPurchasedSeats;
+ _newlyPurchasedSeats = newlyPurchasedSeats;
+ }
+
+ public override List RevertItemsOptions(Subscription subscription)
+ {
+ var subscriptionItem = FindSubscriptionItem(subscription, _planId);
+
+ return
+ [
+ new SubscriptionItemOptions
+ {
+ Id = subscriptionItem.Id,
+ Price = _planId,
+ Quantity = _previouslyPurchasedSeats
+ }
+ ];
+ }
+
+ public override List UpgradeItemsOptions(Subscription subscription)
+ {
+ var subscriptionItem = FindSubscriptionItem(subscription, _planId);
+
+ return
+ [
+ new SubscriptionItemOptions
+ {
+ Id = subscriptionItem.Id,
+ Price = _planId,
+ Quantity = _newlyPurchasedSeats
+ }
+ ];
+ }
+}
diff --git a/src/Core/Models/Business/SeatSubscriptionUpdate.cs b/src/Core/Models/Business/SeatSubscriptionUpdate.cs
index c5ea1a7474..db5104ddd2 100644
--- a/src/Core/Models/Business/SeatSubscriptionUpdate.cs
+++ b/src/Core/Models/Business/SeatSubscriptionUpdate.cs
@@ -18,7 +18,7 @@ public class SeatSubscriptionUpdate : SubscriptionUpdate
public override List UpgradeItemsOptions(Subscription subscription)
{
- var item = SubscriptionItem(subscription, PlanIds.Single());
+ var item = FindSubscriptionItem(subscription, PlanIds.Single());
return new()
{
new SubscriptionItemOptions
@@ -34,7 +34,7 @@ public class SeatSubscriptionUpdate : SubscriptionUpdate
public override List RevertItemsOptions(Subscription subscription)
{
- var item = SubscriptionItem(subscription, PlanIds.Single());
+ var item = FindSubscriptionItem(subscription, PlanIds.Single());
return new()
{
new SubscriptionItemOptions
diff --git a/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs b/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs
index c93212eac8..c3e3e09992 100644
--- a/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs
+++ b/src/Core/Models/Business/ServiceAccountSubscriptionUpdate.cs
@@ -19,7 +19,7 @@ public class ServiceAccountSubscriptionUpdate : SubscriptionUpdate
public override List UpgradeItemsOptions(Subscription subscription)
{
- var item = SubscriptionItem(subscription, PlanIds.Single());
+ var item = FindSubscriptionItem(subscription, PlanIds.Single());
_prevServiceAccounts = item?.Quantity ?? 0;
return new()
{
@@ -35,7 +35,7 @@ public class ServiceAccountSubscriptionUpdate : SubscriptionUpdate
public override List RevertItemsOptions(Subscription subscription)
{
- var item = SubscriptionItem(subscription, PlanIds.Single());
+ var item = FindSubscriptionItem(subscription, PlanIds.Single());
return new()
{
new SubscriptionItemOptions
diff --git a/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs b/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs
index ff6bb55011..b8201b9775 100644
--- a/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs
+++ b/src/Core/Models/Business/SmSeatSubscriptionUpdate.cs
@@ -19,7 +19,7 @@ public class SmSeatSubscriptionUpdate : SubscriptionUpdate
public override List UpgradeItemsOptions(Subscription subscription)
{
- var item = SubscriptionItem(subscription, PlanIds.Single());
+ var item = FindSubscriptionItem(subscription, PlanIds.Single());
return new()
{
new SubscriptionItemOptions
@@ -35,7 +35,7 @@ public class SmSeatSubscriptionUpdate : SubscriptionUpdate
public override List RevertItemsOptions(Subscription subscription)
{
- var item = SubscriptionItem(subscription, PlanIds.Single());
+ var item = FindSubscriptionItem(subscription, PlanIds.Single());
return new()
{
new SubscriptionItemOptions
diff --git a/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs b/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs
index 88af72f199..59a745297b 100644
--- a/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs
+++ b/src/Core/Models/Business/SponsorOrganizationSubscriptionUpdate.cs
@@ -74,10 +74,10 @@ public class SponsorOrganizationSubscriptionUpdate : SubscriptionUpdate
private string AddStripePlanId => _applySponsorship ? _sponsoredPlanStripeId : _existingPlanStripeId;
private Stripe.SubscriptionItem RemoveStripeItem(Subscription subscription) =>
_applySponsorship ?
- SubscriptionItem(subscription, _existingPlanStripeId) :
- SubscriptionItem(subscription, _sponsoredPlanStripeId);
+ FindSubscriptionItem(subscription, _existingPlanStripeId) :
+ FindSubscriptionItem(subscription, _sponsoredPlanStripeId);
private Stripe.SubscriptionItem AddStripeItem(Subscription subscription) =>
_applySponsorship ?
- SubscriptionItem(subscription, _sponsoredPlanStripeId) :
- SubscriptionItem(subscription, _existingPlanStripeId);
+ FindSubscriptionItem(subscription, _sponsoredPlanStripeId) :
+ FindSubscriptionItem(subscription, _existingPlanStripeId);
}
diff --git a/src/Core/Models/Business/StorageSubscriptionUpdate.cs b/src/Core/Models/Business/StorageSubscriptionUpdate.cs
index 30ab2428e2..b0f4a83d3e 100644
--- a/src/Core/Models/Business/StorageSubscriptionUpdate.cs
+++ b/src/Core/Models/Business/StorageSubscriptionUpdate.cs
@@ -17,7 +17,7 @@ public class StorageSubscriptionUpdate : SubscriptionUpdate
public override List UpgradeItemsOptions(Subscription subscription)
{
- var item = SubscriptionItem(subscription, PlanIds.Single());
+ var item = FindSubscriptionItem(subscription, PlanIds.Single());
_prevStorage = item?.Quantity ?? 0;
return new()
{
@@ -38,7 +38,7 @@ public class StorageSubscriptionUpdate : SubscriptionUpdate
throw new Exception("Unknown previous value, must first call UpgradeItemsOptions");
}
- var item = SubscriptionItem(subscription, PlanIds.Single());
+ var item = FindSubscriptionItem(subscription, PlanIds.Single());
return new()
{
new SubscriptionItemOptions
diff --git a/src/Core/Models/Business/SubscriptionUpdate.cs b/src/Core/Models/Business/SubscriptionUpdate.cs
index 70106a10ea..bba9d384d2 100644
--- a/src/Core/Models/Business/SubscriptionUpdate.cs
+++ b/src/Core/Models/Business/SubscriptionUpdate.cs
@@ -1,4 +1,5 @@
-using Stripe;
+using Bit.Core.Enums;
+using Stripe;
namespace Bit.Core.Models.Business;
@@ -15,7 +16,7 @@ public abstract class SubscriptionUpdate
foreach (var upgradeItemOptions in upgradeItemsOptions)
{
var upgradeQuantity = upgradeItemOptions.Quantity ?? 0;
- var existingQuantity = SubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0;
+ var existingQuantity = FindSubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0;
if (upgradeQuantity != existingQuantity)
{
return true;
@@ -24,6 +25,28 @@ public abstract class SubscriptionUpdate
return false;
}
- protected static SubscriptionItem SubscriptionItem(Subscription subscription, string planId) =>
- planId == null ? null : subscription.Items?.Data?.FirstOrDefault(i => i.Plan.Id == planId);
+ protected static SubscriptionItem FindSubscriptionItem(Subscription subscription, string planId)
+ {
+ if (string.IsNullOrEmpty(planId))
+ {
+ return null;
+ }
+
+ var data = subscription.Items.Data;
+
+ var subscriptionItem = data.FirstOrDefault(item => item.Plan?.Id == planId) ?? data.FirstOrDefault(item => item.Price?.Id == planId);
+
+ return subscriptionItem;
+ }
+
+ protected static string GetPasswordManagerPlanId(StaticStore.Plan plan)
+ => IsNonSeatBasedPlan(plan)
+ ? plan.PasswordManager.StripePlanId
+ : plan.PasswordManager.StripeSeatPlanId;
+
+ protected static bool IsNonSeatBasedPlan(StaticStore.Plan plan)
+ => plan.Type is
+ >= PlanType.FamiliesAnnually2019 and <= PlanType.EnterpriseAnnually2019
+ or PlanType.FamiliesAnnually
+ or PlanType.TeamsStarter;
}
diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs
index f8f24cfbdb..e0d2e95dc9 100644
--- a/src/Core/Services/IPaymentService.cs
+++ b/src/Core/Services/IPaymentService.cs
@@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Business;
@@ -28,6 +29,12 @@ public interface IPaymentService
int newlyPurchasedAdditionalStorage,
DateTime? prorationDate = null);
Task AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null);
+ Task AdjustSeats(
+ Provider provider,
+ Plan plan,
+ int currentlySubscribedSeats,
+ int newlySubscribedSeats,
+ DateTime? prorationDate = null);
Task AdjustSmSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null);
Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId, DateTime? prorationDate = null);
diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs
index 19437a1ee2..e89bdacfe1 100644
--- a/src/Core/Services/Implementations/StripePaymentService.cs
+++ b/src/Core/Services/Implementations/StripePaymentService.cs
@@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Billing.Constants;
using Bit.Core.Entities;
using Bit.Core.Enums;
@@ -757,14 +758,14 @@ public class StripePaymentService : IPaymentService
}).ToList();
}
- private async Task FinalizeSubscriptionChangeAsync(IStorableSubscriber storableSubscriber,
+ private async Task FinalizeSubscriptionChangeAsync(ISubscriber subscriber,
SubscriptionUpdate subscriptionUpdate, DateTime? prorationDate, bool invoiceNow = false)
{
// remember, when in doubt, throw
var subGetOptions = new SubscriptionGetOptions();
// subGetOptions.AddExpand("customer");
subGetOptions.AddExpand("customer.tax");
- var sub = await _stripeAdapter.SubscriptionGetAsync(storableSubscriber.GatewaySubscriptionId, subGetOptions);
+ var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions);
if (sub == null)
{
throw new GatewayException("Subscription not found.");
@@ -792,8 +793,8 @@ public class StripePaymentService : IPaymentService
{
var upcomingInvoiceWithChanges = await _stripeAdapter.InvoiceUpcomingAsync(new UpcomingInvoiceOptions
{
- Customer = storableSubscriber.GatewayCustomerId,
- Subscription = storableSubscriber.GatewaySubscriptionId,
+ Customer = subscriber.GatewayCustomerId,
+ Subscription = subscriber.GatewaySubscriptionId,
SubscriptionItems = ToInvoiceSubscriptionItemOptions(updatedItemOptions),
SubscriptionProrationBehavior = Constants.CreateProrations,
SubscriptionProrationDate = prorationDate,
@@ -862,7 +863,7 @@ public class StripePaymentService : IPaymentService
{
if (chargeNow)
{
- paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync(storableSubscriber, invoice);
+ paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync(subscriber, invoice);
}
else
{
@@ -943,6 +944,17 @@ public class StripePaymentService : IPaymentService
return FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate);
}
+ public Task AdjustSeats(
+ Provider provider,
+ StaticStore.Plan plan,
+ int currentlySubscribedSeats,
+ int newlySubscribedSeats,
+ DateTime? prorationDate = null)
+ => FinalizeSubscriptionChangeAsync(
+ provider,
+ new ProviderSubscriptionUpdate(plan.Type, currentlySubscribedSeats, newlySubscribedSeats),
+ prorationDate);
+
public Task AdjustSmSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats, DateTime? prorationDate = null)
{
return FinalizeSubscriptionChangeAsync(organization, new SmSeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate);
diff --git a/src/Core/Utilities/StaticStore.cs b/src/Core/Utilities/StaticStore.cs
index dcf63df138..007f3374e0 100644
--- a/src/Core/Utilities/StaticStore.cs
+++ b/src/Core/Utilities/StaticStore.cs
@@ -147,7 +147,6 @@ public static class StaticStore
public static Plan GetPlan(PlanType planType) => Plans.SingleOrDefault(p => p.Type == planType);
-
public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) =>
SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType);
diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs
new file mode 100644
index 0000000000..57480ac116
--- /dev/null
+++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs
@@ -0,0 +1,130 @@
+using Bit.Api.Billing.Controllers;
+using Bit.Api.Billing.Models;
+using Bit.Core;
+using Bit.Core.Billing.Models;
+using Bit.Core.Billing.Queries;
+using Bit.Core.Context;
+using Bit.Core.Enums;
+using Bit.Core.Services;
+using Bit.Core.Utilities;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using Microsoft.AspNetCore.Http.HttpResults;
+using NSubstitute;
+using NSubstitute.ReturnsExtensions;
+using Stripe;
+using Xunit;
+
+namespace Bit.Api.Test.Billing.Controllers;
+
+[ControllerCustomize(typeof(ProviderBillingController))]
+[SutProviderCustomize]
+public class ProviderBillingControllerTests
+{
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionAsync_FFDisabled_NotFound(
+ Guid providerId,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(false);
+
+ var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
+
+ Assert.IsType(result);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized(
+ Guid providerId,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(true);
+
+ sutProvider.GetDependency().ProviderProviderAdmin(providerId)
+ .Returns(false);
+
+ var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
+
+ Assert.IsType(result);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionAsync_NoSubscriptionData_NotFound(
+ Guid providerId,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(true);
+
+ sutProvider.GetDependency().ProviderProviderAdmin(providerId)
+ .Returns(true);
+
+ sutProvider.GetDependency().GetSubscriptionData(providerId).ReturnsNull();
+
+ var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
+
+ Assert.IsType(result);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionAsync_OK(
+ Guid providerId,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(true);
+
+ sutProvider.GetDependency().ProviderProviderAdmin(providerId)
+ .Returns(true);
+
+ var configuredPlans = new List
+ {
+ new (Guid.NewGuid(), providerId, PlanType.TeamsMonthly, 50, 10, 30),
+ new (Guid.NewGuid(), providerId, PlanType.EnterpriseMonthly, 100, 0, 90)
+ };
+
+ var subscription = new Subscription
+ {
+ Status = "active",
+ CurrentPeriodEnd = new DateTime(2025, 1, 1),
+ Customer = new Customer { Discount = new Discount { Coupon = new Coupon { PercentOff = 10 } } }
+ };
+
+ var providerSubscriptionData = new ProviderSubscriptionData(
+ configuredPlans,
+ subscription);
+
+ sutProvider.GetDependency().GetSubscriptionData(providerId)
+ .Returns(providerSubscriptionData);
+
+ var result = await sutProvider.Sut.GetSubscriptionAsync(providerId);
+
+ Assert.IsType>(result);
+
+ var providerSubscriptionDTO = ((Ok)result).Value;
+
+ Assert.Equal(providerSubscriptionDTO.Status, subscription.Status);
+ Assert.Equal(providerSubscriptionDTO.CurrentPeriodEndDate, subscription.CurrentPeriodEnd);
+ Assert.Equal(providerSubscriptionDTO.DiscountPercentage, subscription.Customer!.Discount!.Coupon!.PercentOff);
+
+ var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
+ var providerTeamsPlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name);
+ Assert.NotNull(providerTeamsPlan);
+ Assert.Equal(50, providerTeamsPlan.SeatMinimum);
+ Assert.Equal(10, providerTeamsPlan.PurchasedSeats);
+ Assert.Equal(30, providerTeamsPlan.AssignedSeats);
+ Assert.Equal(60 * teamsPlan.PasswordManager.SeatPrice, providerTeamsPlan.Cost);
+ Assert.Equal("Monthly", providerTeamsPlan.Cadence);
+
+ var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly);
+ var providerEnterprisePlan = providerSubscriptionDTO.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name);
+ Assert.NotNull(providerEnterprisePlan);
+ Assert.Equal(100, providerEnterprisePlan.SeatMinimum);
+ Assert.Equal(0, providerEnterprisePlan.PurchasedSeats);
+ Assert.Equal(90, providerEnterprisePlan.AssignedSeats);
+ Assert.Equal(100 * enterprisePlan.PasswordManager.SeatPrice, providerEnterprisePlan.Cost);
+ Assert.Equal("Monthly", providerEnterprisePlan.Cadence);
+ }
+}
diff --git a/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs
new file mode 100644
index 0000000000..e75f4bb59e
--- /dev/null
+++ b/test/Api.Test/Billing/Controllers/ProviderOrganizationControllerTests.cs
@@ -0,0 +1,168 @@
+using Bit.Api.Billing.Controllers;
+using Bit.Api.Billing.Models;
+using Bit.Core;
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Repositories;
+using Bit.Core.Billing.Commands;
+using Bit.Core.Context;
+using Bit.Core.Repositories;
+using Bit.Core.Services;
+using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using Microsoft.AspNetCore.Http.HttpResults;
+using NSubstitute;
+using NSubstitute.ReturnsExtensions;
+using Xunit;
+using ProviderOrganization = Bit.Core.AdminConsole.Entities.Provider.ProviderOrganization;
+
+namespace Bit.Api.Test.Billing.Controllers;
+
+[ControllerCustomize(typeof(ProviderOrganizationController))]
+[SutProviderCustomize]
+public class ProviderOrganizationControllerTests
+{
+ [Theory, BitAutoData]
+ public async Task UpdateAsync_FFDisabled_NotFound(
+ Guid providerId,
+ Guid providerOrganizationId,
+ UpdateProviderOrganizationRequestBody requestBody,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(false);
+
+ var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
+
+ Assert.IsType(result);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized(
+ Guid providerId,
+ Guid providerOrganizationId,
+ UpdateProviderOrganizationRequestBody requestBody,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(true);
+
+ sutProvider.GetDependency().ProviderProviderAdmin(providerId)
+ .Returns(false);
+
+ var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
+
+ Assert.IsType(result);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionAsync_NoProvider_NotFound(
+ Guid providerId,
+ Guid providerOrganizationId,
+ UpdateProviderOrganizationRequestBody requestBody,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(true);
+
+ sutProvider.GetDependency().ProviderProviderAdmin(providerId)
+ .Returns(true);
+
+ sutProvider.GetDependency().GetByIdAsync(providerId)
+ .ReturnsNull();
+
+ var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
+
+ Assert.IsType(result);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionAsync_NoProviderOrganization_NotFound(
+ Guid providerId,
+ Guid providerOrganizationId,
+ UpdateProviderOrganizationRequestBody requestBody,
+ Provider provider,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(true);
+
+ sutProvider.GetDependency().ProviderProviderAdmin(providerId)
+ .Returns(true);
+
+ sutProvider.GetDependency().GetByIdAsync(providerId)
+ .Returns(provider);
+
+ sutProvider.GetDependency().GetByIdAsync(providerOrganizationId)
+ .ReturnsNull();
+
+ var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
+
+ Assert.IsType(result);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionAsync_NoOrganization_ServerError(
+ Guid providerId,
+ Guid providerOrganizationId,
+ UpdateProviderOrganizationRequestBody requestBody,
+ Provider provider,
+ ProviderOrganization providerOrganization,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(true);
+
+ sutProvider.GetDependency().ProviderProviderAdmin(providerId)
+ .Returns(true);
+
+ sutProvider.GetDependency().GetByIdAsync(providerId)
+ .Returns(provider);
+
+ sutProvider.GetDependency().GetByIdAsync(providerOrganizationId)
+ .Returns(providerOrganization);
+
+ sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId)
+ .ReturnsNull();
+
+ var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
+
+ Assert.IsType(result);
+ }
+
+ [Theory, BitAutoData]
+ public async Task GetSubscriptionAsync_NoContent(
+ Guid providerId,
+ Guid providerOrganizationId,
+ UpdateProviderOrganizationRequestBody requestBody,
+ Provider provider,
+ ProviderOrganization providerOrganization,
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)
+ .Returns(true);
+
+ sutProvider.GetDependency().ProviderProviderAdmin(providerId)
+ .Returns(true);
+
+ sutProvider.GetDependency().GetByIdAsync(providerId)
+ .Returns(provider);
+
+ sutProvider.GetDependency().GetByIdAsync(providerOrganizationId)
+ .Returns(providerOrganization);
+
+ sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId)
+ .Returns(organization);
+
+ var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody);
+
+ await sutProvider.GetDependency().Received(1)
+ .AssignSeatsToClientOrganization(
+ provider,
+ organization,
+ requestBody.AssignedSeats);
+
+ Assert.IsType(result);
+ }
+}
diff --git a/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs b/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs
new file mode 100644
index 0000000000..918b7c47a2
--- /dev/null
+++ b/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs
@@ -0,0 +1,339 @@
+using Bit.Core.AdminConsole.Entities;
+using Bit.Core.AdminConsole.Entities.Provider;
+using Bit.Core.Billing;
+using Bit.Core.Billing.Commands.Implementations;
+using Bit.Core.Billing.Entities;
+using Bit.Core.Billing.Queries;
+using Bit.Core.Billing.Repositories;
+using Bit.Core.Enums;
+using Bit.Core.Models.StaticStore;
+using Bit.Core.Repositories;
+using Bit.Core.Services;
+using Bit.Core.Utilities;
+using Bit.Test.Common.AutoFixture;
+using Bit.Test.Common.AutoFixture.Attributes;
+using NSubstitute;
+using Xunit;
+
+using static Bit.Core.Test.Billing.Utilities;
+
+namespace Bit.Core.Test.Billing.Commands;
+
+[SutProviderCustomize]
+public class AssignSeatsToClientOrganizationCommandTests
+{
+ [Theory, BitAutoData]
+ public Task AssignSeatsToClientOrganization_NullProvider_ArgumentNullException(
+ Organization organization,
+ int seats,
+ SutProvider sutProvider)
+ => Assert.ThrowsAsync(() =>
+ sutProvider.Sut.AssignSeatsToClientOrganization(null, organization, seats));
+
+ [Theory, BitAutoData]
+ public Task AssignSeatsToClientOrganization_NullOrganization_ArgumentNullException(
+ Provider provider,
+ int seats,
+ SutProvider sutProvider)
+ => Assert.ThrowsAsync(() =>
+ sutProvider.Sut.AssignSeatsToClientOrganization(provider, null, seats));
+
+ [Theory, BitAutoData]
+ public Task AssignSeatsToClientOrganization_NegativeSeats_BillingException(
+ Provider provider,
+ Organization organization,
+ SutProvider sutProvider)
+ => Assert.ThrowsAsync(() =>
+ sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, -5));
+
+ [Theory, BitAutoData]
+ public async Task AssignSeatsToClientOrganization_CurrentSeatsMatchesNewSeats_NoOp(
+ Provider provider,
+ Organization organization,
+ int seats,
+ SutProvider sutProvider)
+ {
+ organization.PlanType = PlanType.TeamsMonthly;
+
+ organization.Seats = seats;
+
+ await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats);
+
+ await sutProvider.GetDependency().DidNotReceive().GetByProviderId(provider.Id);
+ }
+
+ [Theory, BitAutoData]
+ public async Task AssignSeatsToClientOrganization_OrganizationPlanTypeDoesNotSupportConsolidatedBilling_ContactSupport(
+ Provider provider,
+ Organization organization,
+ int seats,
+ SutProvider sutProvider)
+ {
+ organization.PlanType = PlanType.FamiliesAnnually;
+
+ await ThrowsContactSupportAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats));
+ }
+
+ [Theory, BitAutoData]
+ public async Task AssignSeatsToClientOrganization_ProviderPlanIsNotConfigured_ContactSupport(
+ Provider provider,
+ Organization organization,
+ int seats,
+ SutProvider sutProvider)
+ {
+ organization.PlanType = PlanType.TeamsMonthly;
+
+ sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(new List
+ {
+ new ()
+ {
+ Id = Guid.NewGuid(),
+ PlanType = PlanType.TeamsMonthly,
+ ProviderId = provider.Id
+ }
+ });
+
+ await ThrowsContactSupportAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats));
+ }
+
+ [Theory, BitAutoData]
+ public async Task AssignSeatsToClientOrganization_BelowToBelow_Succeeds(
+ Provider provider,
+ Organization organization,
+ SutProvider sutProvider)
+ {
+ organization.Seats = 10;
+
+ organization.PlanType = PlanType.TeamsMonthly;
+
+ // Scale up 10 seats
+ const int seats = 20;
+
+ var providerPlans = new List
+ {
+ new()
+ {
+ Id = Guid.NewGuid(),
+ PlanType = PlanType.TeamsMonthly,
+ ProviderId = provider.Id,
+ PurchasedSeats = 0,
+ // 100 minimum
+ SeatMinimum = 100,
+ AllocatedSeats = 50
+ },
+ new()
+ {
+ Id = Guid.NewGuid(),
+ PlanType = PlanType.EnterpriseMonthly,
+ ProviderId = provider.Id,
+ PurchasedSeats = 0,
+ SeatMinimum = 500,
+ AllocatedSeats = 0
+ }
+ };
+
+ var providerPlan = providerPlans.First();
+
+ sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans);
+
+ // 50 seats currently assigned with a seat minimum of 100
+ sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(50);
+
+ await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats);
+
+ // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum
+ await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().AdjustSeats(
+ Arg.Any(),
+ Arg.Any(),
+ Arg.Any(),
+ Arg.Any