1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-27 15:52:13 -05:00

Merge branch 'main' into feature/phishing-detection

This commit is contained in:
Cy Okeke 2025-04-18 17:05:03 +01:00
commit 5a4bcdaabf
No known key found for this signature in database
GPG Key ID: 88B341B55C84B45C
457 changed files with 41827 additions and 4411 deletions

12
.github/CODEOWNERS vendored
View File

@ -20,16 +20,25 @@
# Database Operations for database changes
src/Sql/** @bitwarden/dept-dbops
util/EfShared/** @bitwarden/dept-dbops
util/Migrator/** @bitwarden/dept-dbops
util/Migrator/** @bitwarden/team-platform-dev # The Platform team owns the Migrator project code
util/Migrator/DbScripts/** @bitwarden/dept-dbops
util/Migrator/DbScripts_finalization/** @bitwarden/dept-dbops
util/Migrator/DbScripts_transition/** @bitwarden/dept-dbops
util/Migrator/MySql/** @bitwarden/dept-dbops
util/MySqlMigrations/** @bitwarden/dept-dbops
util/PostgresMigrations/** @bitwarden/dept-dbops
util/SqlServerEFScaffold/** @bitwarden/dept-dbops
util/SqliteMigrations/** @bitwarden/dept-dbops
# Shared util projects
util/Setup/** @bitwarden/dept-bre @bitwarden/team-platform-dev
# Auth team
**/Auth @bitwarden/team-auth-dev
bitwarden_license/src/Sso @bitwarden/team-auth-dev
src/Identity @bitwarden/team-auth-dev
src/Core/Identity @bitwarden/team-auth-dev
src/Core/IdentityServer @bitwarden/team-auth-dev
# Key Management team
**/KeyManagement @bitwarden/team-key-management-dev
@ -66,6 +75,7 @@ src/Admin/Views/Tools @bitwarden/team-billing-dev
# Platform team
.github/workflows/build.yml @bitwarden/team-platform-dev
.github/workflows/build_target.yml @bitwarden/team-platform-dev
.github/workflows/cleanup-after-pr.yml @bitwarden/team-platform-dev
.github/workflows/cleanup-rc-branch.yml @bitwarden/team-platform-dev
.github/workflows/repository-management.yml @bitwarden/team-platform-dev

View File

@ -7,22 +7,18 @@ on:
- "main"
- "rc"
- "hotfix-rc"
pull_request_target:
pull_request:
types: [opened, synchronize]
workflow_call:
inputs: {}
env:
_AZ_REGISTRY: "bitwardenprod.azurecr.io"
jobs:
check-run:
name: Check PR run
uses: bitwarden/gh-actions/.github/workflows/check-run.yml@main
lint:
name: Lint
runs-on: ubuntu-22.04
needs:
- check-run
steps:
- name: Check out repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -40,6 +36,8 @@ jobs:
runs-on: ubuntu-22.04
needs:
- lint
outputs:
has_secrets: ${{ steps.check-secrets.outputs.has_secrets }}
strategy:
fail-fast: false
matrix:
@ -75,6 +73,14 @@ jobs:
base_path: ./bitwarden_license/src
node: true
steps:
- name: Check secrets
id: check-secrets
env:
AZURE_KV_CI_SERVICE_PRINCIPAL: ${{ secrets.AZURE_KV_CI_SERVICE_PRINCIPAL }}
run: |
has_secrets=${{ secrets.AZURE_KV_CI_SERVICE_PRINCIPAL != '' }}
echo "has_secrets=$has_secrets" >> $GITHUB_OUTPUT
- name: Check out repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
@ -134,6 +140,7 @@ jobs:
id-token: write
needs:
- build-artifacts
if: ${{ needs.build-artifacts.outputs.has_secrets == 'true' }}
strategy:
fail-fast: false
matrix:
@ -227,7 +234,7 @@ jobs:
- name: Generate Docker image tag
id: tag
run: |
if [[ "${GITHUB_EVENT_NAME}" == "pull_request_target" ]]; then
if [[ "${GITHUB_EVENT_NAME}" == "pull_request" ]]; then
IMAGE_TAG=$(echo "${GITHUB_HEAD_REF}" | sed "s#/#-#g")
else
IMAGE_TAG=$(echo "${GITHUB_REF:11}" | sed "s#/#-#g")
@ -289,11 +296,11 @@ jobs:
"GH_PAT=${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}"
- name: Install Cosign
if: github.event_name != 'pull_request_target' && github.ref == 'refs/heads/main'
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
uses: sigstore/cosign-installer@dc72c7d5c4d10cd6bcb8cf6e3fd625a9e5e537da # v3.7.0
- name: Sign image with Cosign
if: github.event_name != 'pull_request_target' && github.ref == 'refs/heads/main'
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
env:
DIGEST: ${{ steps.build-docker.outputs.digest }}
TAGS: ${{ steps.image-tags.outputs.tags }}
@ -317,6 +324,8 @@ jobs:
uses: github/codeql-action/upload-sarif@dd746615b3b9d728a6a37ca2045b68ca76d4841a # v3.28.8
with:
sarif_file: ${{ steps.container-scan.outputs.sarif }}
sha: ${{ contains(github.event_name, 'pull_request') && github.event.pull_request.head.sha || github.sha }}
ref: ${{ contains(github.event_name, 'pull_request') && format('refs/pull/{0}/head', github.event.pull_request.number) || github.ref }}
upload:
name: Upload
@ -341,7 +350,7 @@ jobs:
- name: Make Docker stubs
if: |
github.event_name != 'pull_request_target'
github.event_name != 'pull_request'
&& (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc')
run: |
# Set proper setup image based on branch
@ -383,7 +392,7 @@ jobs:
- name: Make Docker stub checksums
if: |
github.event_name != 'pull_request_target'
github.event_name != 'pull_request'
&& (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc')
run: |
sha256sum docker-stub-US.zip > docker-stub-US-sha256.txt
@ -391,7 +400,7 @@ jobs:
- name: Upload Docker stub US artifact
if: |
github.event_name != 'pull_request_target'
github.event_name != 'pull_request'
&& (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc')
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
with:
@ -401,7 +410,7 @@ jobs:
- name: Upload Docker stub EU artifact
if: |
github.event_name != 'pull_request_target'
github.event_name != 'pull_request'
&& (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc')
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
with:
@ -411,7 +420,7 @@ jobs:
- name: Upload Docker stub US checksum artifact
if: |
github.event_name != 'pull_request_target'
github.event_name != 'pull_request'
&& (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc')
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
with:
@ -421,7 +430,7 @@ jobs:
- name: Upload Docker stub EU checksum artifact
if: |
github.event_name != 'pull_request_target'
github.event_name != 'pull_request'
&& (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc')
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
with:
@ -550,7 +559,7 @@ jobs:
self-host-build:
name: Trigger self-host build
if: |
github.event_name != 'pull_request_target'
github.event_name != 'pull_request'
&& (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc')
runs-on: ubuntu-22.04
needs:
@ -585,7 +594,7 @@ jobs:
trigger-k8s-deploy:
name: Trigger k8s deploy
if: github.event_name != 'pull_request_target' && github.ref == 'refs/heads/main'
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
runs-on: ubuntu-22.04
needs:
- build-docker
@ -618,57 +627,19 @@ jobs:
}
})
trigger-ee-updates:
name: Trigger Ephemeral Environment updates
setup-ephemeral-environment:
name: Setup Ephemeral Environment
needs: build-docker
if: |
github.event_name == 'pull_request_target'
&& contains(github.event.pull_request.labels.*.name, 'ephemeral-environment')
runs-on: ubuntu-24.04
needs:
- build-docker
steps:
- name: Log in to Azure - CI subscription
uses: Azure/login@e15b166166a8746d1a47596803bd8c1b595455cf # v1.6.0
with:
creds: ${{ secrets.AZURE_KV_CI_SERVICE_PRINCIPAL }}
- name: Retrieve GitHub PAT secrets
id: retrieve-secret-pat
uses: bitwarden/gh-actions/get-keyvault-secrets@main
with:
keyvault: "bitwarden-ci"
secrets: "github-pat-bitwarden-devops-bot-repo-scope"
- name: Trigger Ephemeral Environment update
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}
script: |
await github.rest.actions.createWorkflowDispatch({
owner: 'bitwarden',
repo: 'devops',
workflow_id: '_update_ephemeral_tags.yml',
ref: 'main',
inputs: {
ephemeral_env_branch: process.env.GITHUB_HEAD_REF
}
})
trigger-ephemeral-environment-sync:
name: Trigger Ephemeral Environment Sync
needs: trigger-ee-updates
if: |
github.event_name == 'pull_request_target'
needs.build-artifacts.outputs.has_secrets == 'true'
&& github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'ephemeral-environment')
uses: bitwarden/gh-actions/.github/workflows/_ephemeral_environment_manager.yml@main
with:
ephemeral_env_branch: process.env.GITHUB_HEAD_REF
project: server
sync_environment: true
pull_request_number: ${{ github.event.number }}
secrets: inherit
check-failures:
name: Check for failures
if: always()
@ -684,7 +655,7 @@ jobs:
steps:
- name: Check if any job failed
if: |
github.event_name != 'pull_request_target'
github.event_name != 'pull_request'
&& (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc')
&& contains(needs.*.result, 'failure')
run: exit 1

21
.github/workflows/build_target.yml vendored Normal file
View File

@ -0,0 +1,21 @@
name: Build on PR Target
on:
pull_request_target:
types: [opened, synchronize]
defaults:
run:
shell: bash
jobs:
check-run:
name: Check PR run
uses: bitwarden/gh-actions/.github/workflows/check-run.yml@main
run-workflow:
name: Run Build on PR Target
needs: check-run
if: ${{ github.event.pull_request.head.repo.full_name != github.repository }}
uses: ./.github/workflows/build.yml
secrets: inherit

View File

@ -5,34 +5,12 @@ on:
types: [labeled]
jobs:
trigger-ee-updates:
name: Trigger Ephemeral Environment updates
runs-on: ubuntu-24.04
setup-ephemeral-environment:
name: Setup Ephemeral Environment
if: github.event.label.name == 'ephemeral-environment'
steps:
- name: Log in to Azure - CI subscription
uses: Azure/login@e15b166166a8746d1a47596803bd8c1b595455cf # v1.6.0
with:
creds: ${{ secrets.AZURE_KV_CI_SERVICE_PRINCIPAL }}
- name: Retrieve GitHub PAT secrets
id: retrieve-secret-pat
uses: bitwarden/gh-actions/get-keyvault-secrets@main
with:
keyvault: "bitwarden-ci"
secrets: "github-pat-bitwarden-devops-bot-repo-scope"
- name: Trigger Ephemeral Environment update
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}
script: |
await github.rest.actions.createWorkflowDispatch({
owner: 'bitwarden',
repo: 'devops',
workflow_id: '_update_ephemeral_tags.yml',
ref: 'main',
inputs: {
ephemeral_env_branch: process.env.GITHUB_HEAD_REF
}
})
uses: bitwarden/gh-actions/.github/workflows/_ephemeral_environment_manager.yml@main
with:
project: server
pull_request_number: ${{ github.event.number }}
sync_environment: true
secrets: inherit

View File

@ -49,6 +49,8 @@ jobs:
uses: github/codeql-action/upload-sarif@dd746615b3b9d728a6a37ca2045b68ca76d4841a # v3.28.8
with:
sarif_file: cx_result.sarif
sha: ${{ contains(github.event_name, 'pull_request') && github.event.pull_request.head.sha || github.sha }}
ref: ${{ contains(github.event_name, 'pull_request') && format('refs/pull/{0}/head', github.event.pull_request.number) || github.ref }}
quality:
name: Quality scan

View File

@ -3,7 +3,7 @@
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<Version>2025.3.0</Version>
<Version>2025.4.1</Version>
<RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace>
<ImplicitUsings>enable</ImplicitUsings>

View File

@ -127,6 +127,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Infrastructure.Dapper.Test"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Events.IntegrationTest", "test\Events.IntegrationTest\Events.IntegrationTest.csproj", "{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Core.IntegrationTest", "test\Core.IntegrationTest\Core.IntegrationTest.csproj", "{3631BA42-6731-4118-A917-DAA43C5032B9}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -319,6 +321,10 @@ Global
{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401}.Release|Any CPU.Build.0 = Release|Any CPU
{3631BA42-6731-4118-A917-DAA43C5032B9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{3631BA42-6731-4118-A917-DAA43C5032B9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{3631BA42-6731-4118-A917-DAA43C5032B9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3631BA42-6731-4118-A917-DAA43C5032B9}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -370,6 +376,7 @@ Global
{90D85D8F-5577-4570-A96E-5A2E185F0F6F} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F}
{4A725DB3-BE4F-4C23-9087-82D0610D67AF} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F}
{4F4C63A9-AEE2-48C4-AB86-A5BCD665E401} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F}
{3631BA42-6731-4118-A917-DAA43C5032B9} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {E01CBF68-2E20-425F-9EDB-E0A6510CA92F}

View File

@ -48,7 +48,7 @@ public class CreateProviderCommand : ICreateProviderCommand
await ProviderRepositoryCreateAsync(provider, ProviderStatusType.Created);
}
public async Task CreateMultiOrganizationEnterpriseAsync(Provider provider, string ownerEmail, PlanType plan, int minimumSeats)
public async Task CreateBusinessUnitAsync(Provider provider, string ownerEmail, PlanType plan, int minimumSeats)
{
var providerId = await CreateProviderAsync(provider, ownerEmail);

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.Providers.Interfaces;
@ -7,10 +8,12 @@ using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Microsoft.Extensions.DependencyInjection;
using Stripe;
namespace Bit.Commercial.Core.AdminConsole.Providers;
@ -28,6 +31,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
private readonly ISubscriberService _subscriberService;
private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery;
private readonly IPricingClient _pricingClient;
private readonly IAutomaticTaxStrategy _automaticTaxStrategy;
public RemoveOrganizationFromProviderCommand(
IEventService eventService,
@ -40,7 +44,8 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
IProviderBillingService providerBillingService,
ISubscriberService subscriberService,
IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery,
IPricingClient pricingClient)
IPricingClient pricingClient,
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
{
_eventService = eventService;
_mailService = mailService;
@ -53,6 +58,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
_subscriberService = subscriberService;
_hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery;
_pricingClient = pricingClient;
_automaticTaxStrategy = automaticTaxStrategy;
}
public async Task RemoveOrganizationFromProvider(
@ -107,10 +113,11 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
organization.IsValidClient() &&
!string.IsNullOrEmpty(organization.GatewayCustomerId))
{
await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions
var customer = await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions
{
Description = string.Empty,
Email = organization.BillingEmail
Email = organization.BillingEmail,
Expand = ["tax", "tax_ids"]
});
var plan = await _pricingClient.GetPlanOrThrow(organization.PlanType);
@ -120,7 +127,6 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
Customer = organization.GatewayCustomerId,
CollectionMethod = StripeConstants.CollectionMethod.SendInvoice,
DaysUntilDue = 30,
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true },
Metadata = new Dictionary<string, string>
{
{ "organizationId", organization.Id.ToString() }
@ -130,6 +136,18 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv
Items = [new SubscriptionItemOptions { Price = plan.PasswordManager.StripeSeatPlanId, Quantity = organization.Seats }]
};
if (_featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
_automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax ??= new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
}
var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
organization.GatewaySubscriptionId = subscription.Id;

View File

@ -692,10 +692,10 @@ public class ProviderService : IProviderService
throw new BadRequestException($"Managed Service Providers cannot manage organizations with the plan type {requestedType}. Only Teams (Monthly) and Enterprise (Monthly) are allowed.");
}
break;
case ProviderType.MultiOrganizationEnterprise:
case ProviderType.BusinessUnit:
if (requestedType is not (PlanType.EnterpriseMonthly or PlanType.EnterpriseAnnually))
{
throw new BadRequestException($"Multi-organization Enterprise Providers cannot manage organizations with the plan type {requestedType}. Only Enterprise (Monthly) and Enterprise (Annually) are allowed.");
throw new BadRequestException($"Business Unit Providers cannot manage organizations with the plan type {requestedType}. Only Enterprise (Monthly) and Enterprise (Annually) are allowed.");
}
break;
case ProviderType.Reseller:

View File

@ -0,0 +1,462 @@
#nullable enable
using System.Diagnostics.CodeAnalysis;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.Extensions.Logging;
using OneOf;
using Stripe;
namespace Bit.Commercial.Core.Billing;
[RequireFeature(FeatureFlagKeys.PM18770_EnableOrganizationBusinessUnitConversion)]
public class BusinessUnitConverter(
IDataProtectionProvider dataProtectionProvider,
GlobalSettings globalSettings,
ILogger<BusinessUnitConverter> logger,
IMailService mailService,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IPricingClient pricingClient,
IProviderOrganizationRepository providerOrganizationRepository,
IProviderPlanRepository providerPlanRepository,
IProviderRepository providerRepository,
IProviderUserRepository providerUserRepository,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
IUserRepository userRepository) : IBusinessUnitConverter
{
private readonly IDataProtector _dataProtector =
dataProtectionProvider.CreateProtector($"{nameof(BusinessUnitConverter)}DataProtector");
public async Task<Guid> FinalizeConversion(
Organization organization,
Guid userId,
string token,
string providerKey,
string organizationKey)
{
var user = await userRepository.GetByIdAsync(userId);
var (subscription, provider, providerOrganization, providerUser) = await ValidateFinalizationAsync(organization, user, token);
var existingPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var updatedPlan = await pricingClient.GetPlanOrThrow(existingPlan.IsAnnual ? PlanType.EnterpriseAnnually : PlanType.EnterpriseMonthly);
// Bring organization under management.
organization.Plan = updatedPlan.Name;
organization.PlanType = updatedPlan.Type;
organization.MaxCollections = updatedPlan.PasswordManager.MaxCollections;
organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb;
organization.UsePolicies = updatedPlan.HasPolicies;
organization.UseSso = updatedPlan.HasSso;
organization.UseGroups = updatedPlan.HasGroups;
organization.UseEvents = updatedPlan.HasEvents;
organization.UseDirectory = updatedPlan.HasDirectory;
organization.UseTotp = updatedPlan.HasTotp;
organization.Use2fa = updatedPlan.Has2fa;
organization.UseApi = updatedPlan.HasApi;
organization.UseResetPassword = updatedPlan.HasResetPassword;
organization.SelfHost = updatedPlan.HasSelfHost;
organization.UsersGetPremium = updatedPlan.UsersGetPremium;
organization.UseCustomPermissions = updatedPlan.HasCustomPermissions;
organization.UseScim = updatedPlan.HasScim;
organization.UseKeyConnector = updatedPlan.HasKeyConnector;
organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb;
organization.BillingEmail = provider.BillingEmail!;
organization.GatewayCustomerId = null;
organization.GatewaySubscriptionId = null;
organization.ExpirationDate = null;
organization.MaxAutoscaleSeats = null;
organization.Status = OrganizationStatusType.Managed;
// Enable organization access via key exchange.
providerOrganization.Key = organizationKey;
// Complete provider setup.
provider.Gateway = GatewayType.Stripe;
provider.GatewayCustomerId = subscription.CustomerId;
provider.GatewaySubscriptionId = subscription.Id;
provider.Status = ProviderStatusType.Billable;
// Enable provider access via key exchange.
providerUser.Key = providerKey;
providerUser.Status = ProviderUserStatusType.Confirmed;
// Stripe requires that we clear all the custom fields from the invoice settings if we want to replace them.
await stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions
{
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields = []
}
});
var metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.OrganizationId] = string.Empty,
[StripeConstants.MetadataKeys.ProviderId] = provider.Id.ToString(),
["convertedFrom"] = organization.Id.ToString()
};
var updateCustomer = stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions
{
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields = [
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = provider.SubscriberType(),
Value = provider.DisplayName()?.Length <= 30
? provider.DisplayName()
: provider.DisplayName()?[..30]
}
]
},
Metadata = metadata
});
// Find the existing password manager price on the subscription.
var passwordManagerItem = subscription.Items.First(item =>
{
var priceId = existingPlan.HasNonSeatBasedPasswordManagerPlan()
? existingPlan.PasswordManager.StripePlanId
: existingPlan.PasswordManager.StripeSeatPlanId;
return item.Price.Id == priceId;
});
// Get the new business unit price.
var updatedPriceId = ProviderPriceAdapter.GetActivePriceId(provider, updatedPlan.Type);
// Replace the existing password manager price with the new business unit price.
var updateSubscription =
stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
new SubscriptionUpdateOptions
{
Items = [
new SubscriptionItemOptions
{
Id = passwordManagerItem.Id,
Deleted = true
},
new SubscriptionItemOptions
{
Price = updatedPriceId,
Quantity = organization.Seats
}
],
Metadata = metadata
});
await Task.WhenAll(updateCustomer, updateSubscription);
// Complete database updates for provider setup.
await Task.WhenAll(
organizationRepository.ReplaceAsync(organization),
providerOrganizationRepository.ReplaceAsync(providerOrganization),
providerRepository.ReplaceAsync(provider),
providerUserRepository.ReplaceAsync(providerUser));
return provider.Id;
}
public async Task<OneOf<Guid, List<string>>> InitiateConversion(
Organization organization,
string providerAdminEmail)
{
var user = await userRepository.GetByEmailAsync(providerAdminEmail);
var problems = await ValidateInitiationAsync(organization, user);
if (problems is { Count: > 0 })
{
return problems;
}
var provider = await providerRepository.CreateAsync(new Provider
{
Name = organization.Name,
BillingEmail = organization.BillingEmail,
Status = ProviderStatusType.Pending,
UseEvents = true,
Type = ProviderType.BusinessUnit
});
var plan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var managedPlanType = plan.IsAnnual
? PlanType.EnterpriseAnnually
: PlanType.EnterpriseMonthly;
var createProviderOrganization = providerOrganizationRepository.CreateAsync(new ProviderOrganization
{
ProviderId = provider.Id,
OrganizationId = organization.Id
});
var createProviderPlan = providerPlanRepository.CreateAsync(new ProviderPlan
{
ProviderId = provider.Id,
PlanType = managedPlanType,
SeatMinimum = 0,
PurchasedSeats = organization.Seats,
AllocatedSeats = organization.Seats
});
var createProviderUser = providerUserRepository.CreateAsync(new ProviderUser
{
ProviderId = provider.Id,
UserId = user!.Id,
Email = user.Email,
Status = ProviderUserStatusType.Invited,
Type = ProviderUserType.ProviderAdmin
});
await Task.WhenAll(createProviderOrganization, createProviderPlan, createProviderUser);
await SendInviteAsync(organization, user.Email);
return provider.Id;
}
public Task ResendConversionInvite(
Organization organization,
string providerAdminEmail) =>
IfConversionInProgressAsync(organization, providerAdminEmail,
async (_, _, providerUser) =>
{
if (!string.IsNullOrEmpty(providerUser.Email))
{
await SendInviteAsync(organization, providerUser.Email);
}
});
public Task ResetConversion(
Organization organization,
string providerAdminEmail) =>
IfConversionInProgressAsync(organization, providerAdminEmail,
async (provider, providerOrganization, providerUser) =>
{
var tasks = new List<Task>
{
providerOrganizationRepository.DeleteAsync(providerOrganization),
providerUserRepository.DeleteAsync(providerUser)
};
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
if (providerPlans is { Count: > 0 })
{
tasks.AddRange(providerPlans.Select(providerPlanRepository.DeleteAsync));
}
await Task.WhenAll(tasks);
await providerRepository.DeleteAsync(provider);
});
#region Utilities
private async Task IfConversionInProgressAsync(
Organization organization,
string providerAdminEmail,
Func<Provider, ProviderOrganization, ProviderUser, Task> callback)
{
var user = await userRepository.GetByEmailAsync(providerAdminEmail);
if (user == null)
{
return;
}
var provider = await providerRepository.GetByOrganizationIdAsync(organization.Id);
if (provider is not
{
Type: ProviderType.BusinessUnit,
Status: ProviderStatusType.Pending
})
{
return;
}
var providerUser = await providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id);
if (providerUser is
{
Type: ProviderUserType.ProviderAdmin,
Status: ProviderUserStatusType.Invited
})
{
var providerOrganization = await providerOrganizationRepository.GetByOrganizationId(organization.Id);
await callback(provider, providerOrganization!, providerUser);
}
}
private async Task SendInviteAsync(
Organization organization,
string providerAdminEmail)
{
var token = _dataProtector.Protect(
$"BusinessUnitConversionInvite {organization.Id} {providerAdminEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
await mailService.SendBusinessUnitConversionInviteAsync(organization, token, providerAdminEmail);
}
private async Task<(Subscription, Provider, ProviderOrganization, ProviderUser)> ValidateFinalizationAsync(
Organization organization,
User? user,
string token)
{
if (organization.PlanType.GetProductTier() != ProductTierType.Enterprise)
{
Fail("Organization must be on an enterprise plan.");
}
var subscription = await subscriberService.GetSubscription(organization);
if (subscription is not
{
Status:
StripeConstants.SubscriptionStatus.Active or
StripeConstants.SubscriptionStatus.Trialing or
StripeConstants.SubscriptionStatus.PastDue
})
{
Fail("Organization must have a valid subscription.");
}
if (user == null)
{
Fail("Provider admin must be a Bitwarden user.");
}
if (!CoreHelpers.TokenIsValid(
"BusinessUnitConversionInvite",
_dataProtector,
token,
user.Email,
organization.Id,
globalSettings.OrganizationInviteExpirationHours))
{
Fail("Email token is invalid.");
}
var organizationUser =
await organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id);
if (organizationUser is not
{
Status: OrganizationUserStatusType.Confirmed
})
{
Fail("Provider admin must be a confirmed member of the organization being converted.");
}
var provider = await providerRepository.GetByOrganizationIdAsync(organization.Id);
if (provider is not
{
Type: ProviderType.BusinessUnit,
Status: ProviderStatusType.Pending
})
{
Fail("Linked provider is not a pending business unit.");
}
var providerUser = await providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id);
if (providerUser is not
{
Type: ProviderUserType.ProviderAdmin,
Status: ProviderUserStatusType.Invited
})
{
Fail("Provider admin has not been invited.");
}
var providerOrganization = await providerOrganizationRepository.GetByOrganizationId(organization.Id);
return (subscription, provider, providerOrganization!, providerUser);
[DoesNotReturn]
void Fail(string scopedError)
{
logger.LogError("Could not finalize business unit conversion for organization ({OrganizationID}): {Error}",
organization.Id, scopedError);
throw new BillingException();
}
}
private async Task<List<string>?> ValidateInitiationAsync(
Organization organization,
User? user)
{
var problems = new List<string>();
if (organization.PlanType.GetProductTier() != ProductTierType.Enterprise)
{
problems.Add("Organization must be on an enterprise plan.");
}
var subscription = await subscriberService.GetSubscription(organization);
if (subscription is not
{
Status:
StripeConstants.SubscriptionStatus.Active or
StripeConstants.SubscriptionStatus.Trialing or
StripeConstants.SubscriptionStatus.PastDue
})
{
problems.Add("Organization must have a valid subscription.");
}
var providerOrganization = await providerOrganizationRepository.GetByOrganizationId(organization.Id);
if (providerOrganization != null)
{
problems.Add("Organization is already linked to a provider.");
}
if (user == null)
{
problems.Add("Provider admin must be a Bitwarden user.");
}
else
{
var organizationUser =
await organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id);
if (organizationUser is not
{
Status: OrganizationUserStatusType.Confirmed
})
{
problems.Add("Provider admin must be a confirmed member of the organization being converted.");
}
}
return problems.Count == 0 ? null : problems;
}
#endregion
}

View File

@ -14,6 +14,7 @@ using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.Billing.Services.Implementations.AutomaticTax;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
@ -22,6 +23,7 @@ using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using CsvHelper;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Stripe;
@ -29,10 +31,10 @@ namespace Bit.Commercial.Core.Billing;
public class ProviderBillingService(
IEventService eventService,
IFeatureService featureService,
IGlobalSettings globalSettings,
ILogger<ProviderBillingService> logger,
IOrganizationRepository organizationRepository,
IPaymentService paymentService,
IPricingClient pricingClient,
IProviderInvoiceItemRepository providerInvoiceItemRepository,
IProviderOrganizationRepository providerOrganizationRepository,
@ -40,7 +42,9 @@ public class ProviderBillingService(
IProviderUserRepository providerUserRepository,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
ITaxService taxService) : IProviderBillingService
ITaxService taxService,
[FromKeyedServices(AutomaticTaxFactory.BusinessUse)] IAutomaticTaxStrategy automaticTaxStrategy)
: IProviderBillingService
{
[RequireFeature(FeatureFlagKeys.P15179_AddExistingOrgsFromProviderPortal)]
public async Task AddExistingOrganization(
@ -143,36 +147,29 @@ public class ProviderBillingService(
public async Task ChangePlan(ChangeProviderPlanCommand command)
{
var plan = await providerPlanRepository.GetByIdAsync(command.ProviderPlanId);
var (provider, providerPlanId, newPlanType) = command;
if (plan == null)
var providerPlan = await providerPlanRepository.GetByIdAsync(providerPlanId);
if (providerPlan == null)
{
throw new BadRequestException("Provider plan not found.");
}
if (plan.PlanType == command.NewPlan)
if (providerPlan.PlanType == newPlanType)
{
return;
}
var oldPlanConfiguration = await pricingClient.GetPlanOrThrow(plan.PlanType);
var newPlanConfiguration = await pricingClient.GetPlanOrThrow(command.NewPlan);
var subscription = await subscriberService.GetSubscriptionOrThrow(provider);
plan.PlanType = command.NewPlan;
await providerPlanRepository.ReplaceAsync(plan);
var oldPriceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType);
var newPriceId = ProviderPriceAdapter.GetPriceId(provider, subscription, newPlanType);
Subscription subscription;
try
{
subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, plan.ProviderId);
}
catch (InvalidOperationException)
{
throw new ConflictException("Subscription not found.");
}
providerPlan.PlanType = newPlanType;
await providerPlanRepository.ReplaceAsync(providerPlan);
var oldSubscriptionItem = subscription.Items.SingleOrDefault(x =>
x.Price.Id == oldPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId);
var oldSubscriptionItem = subscription.Items.SingleOrDefault(x => x.Price.Id == oldPriceId);
var updateOptions = new SubscriptionUpdateOptions
{
@ -180,7 +177,7 @@ public class ProviderBillingService(
[
new SubscriptionItemOptions
{
Price = newPlanConfiguration.PasswordManager.StripeProviderPortalSeatPlanId,
Price = newPriceId,
Quantity = oldSubscriptionItem!.Quantity
},
new SubscriptionItemOptions
@ -191,12 +188,14 @@ public class ProviderBillingService(
]
};
await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId, updateOptions);
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, updateOptions);
// Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId)
// 1. Retrieve PlanType and PlanName for ProviderPlan
// 2. Assign PlanType & PlanName to Organization
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(plan.ProviderId);
var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId);
var newPlan = await pricingClient.GetPlanOrThrow(newPlanType);
foreach (var providerOrganization in providerOrganizations)
{
@ -205,8 +204,8 @@ public class ProviderBillingService(
{
throw new ConflictException($"Organization '{providerOrganization.Id}' not found.");
}
organization.PlanType = command.NewPlan;
organization.Plan = newPlanConfiguration.Name;
organization.PlanType = newPlanType;
organization.Plan = newPlan.Name;
await organizationRepository.ReplaceAsync(organization);
}
}
@ -400,7 +399,7 @@ public class ProviderBillingService(
var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment;
var update = CurrySeatScalingUpdate(
var scaleQuantityTo = CurrySeatScalingUpdate(
provider,
providerPlan,
newlyAssignedSeatTotal);
@ -423,9 +422,7 @@ public class ProviderBillingService(
else if (currentlyAssignedSeatTotal <= seatMinimum &&
newlyAssignedSeatTotal > seatMinimum)
{
await update(
seatMinimum,
newlyAssignedSeatTotal);
await scaleQuantityTo(newlyAssignedSeatTotal);
}
/*
* Above the limit => Above the limit:
@ -434,9 +431,7 @@ public class ProviderBillingService(
else if (currentlyAssignedSeatTotal > seatMinimum &&
newlyAssignedSeatTotal > seatMinimum)
{
await update(
currentlyAssignedSeatTotal,
newlyAssignedSeatTotal);
await scaleQuantityTo(newlyAssignedSeatTotal);
}
/*
* Above the limit => Below the limit:
@ -445,9 +440,7 @@ public class ProviderBillingService(
else if (currentlyAssignedSeatTotal > seatMinimum &&
newlyAssignedSeatTotal <= seatMinimum)
{
await update(
currentlyAssignedSeatTotal,
seatMinimum);
await scaleQuantityTo(seatMinimum);
}
}
@ -557,7 +550,8 @@ public class ProviderBillingService(
{
ArgumentNullException.ThrowIfNull(provider);
var customer = await subscriberService.GetCustomerOrThrow(provider);
var customerGetOptions = new CustomerGetOptions { Expand = ["tax", "tax_ids"] };
var customer = await subscriberService.GetCustomerOrThrow(provider, customerGetOptions);
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
@ -580,19 +574,17 @@ public class ProviderBillingService(
throw new BillingException();
}
var priceId = ProviderPriceAdapter.GetActivePriceId(provider, providerPlan.PlanType);
subscriptionItemOptionsList.Add(new SubscriptionItemOptions
{
Price = plan.PasswordManager.StripeProviderPortalSeatPlanId,
Price = priceId,
Quantity = providerPlan.SeatMinimum
});
}
var subscriptionCreateOptions = new SubscriptionCreateOptions
{
AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
},
CollectionMethod = StripeConstants.CollectionMethod.SendInvoice,
Customer = customer.Id,
DaysUntilDue = 30,
@ -605,6 +597,15 @@ public class ProviderBillingService(
ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations
};
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
automaticTaxStrategy.SetCreateOptions(subscriptionCreateOptions, customer);
}
else
{
subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true };
}
try
{
var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions);
@ -643,43 +644,37 @@ public class ProviderBillingService(
public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command)
{
if (command.Configuration.Any(x => x.SeatsMinimum < 0))
var (provider, updatedPlanConfigurations) = command;
if (updatedPlanConfigurations.Any(x => x.SeatsMinimum < 0))
{
throw new BadRequestException("Provider seat minimums must be at least 0.");
}
Subscription subscription;
try
{
subscription = await stripeAdapter.ProviderSubscriptionGetAsync(command.GatewaySubscriptionId, command.Id);
}
catch (InvalidOperationException)
{
throw new ConflictException("Subscription not found.");
}
var subscription = await subscriberService.GetSubscriptionOrThrow(provider);
var subscriptionItemOptionsList = new List<SubscriptionItemOptions>();
var providerPlans = await providerPlanRepository.GetByProviderId(command.Id);
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
foreach (var newPlanConfiguration in command.Configuration)
foreach (var updatedPlanConfiguration in updatedPlanConfigurations)
{
var (updatedPlanType, updatedSeatMinimum) = updatedPlanConfiguration;
var providerPlan =
providerPlans.Single(providerPlan => providerPlan.PlanType == newPlanConfiguration.Plan);
providerPlans.Single(providerPlan => providerPlan.PlanType == updatedPlanType);
if (providerPlan.SeatMinimum != newPlanConfiguration.SeatsMinimum)
if (providerPlan.SeatMinimum != updatedSeatMinimum)
{
var newPlan = await pricingClient.GetPlanOrThrow(newPlanConfiguration.Plan);
var priceId = newPlan.PasswordManager.StripeProviderPortalSeatPlanId;
var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, updatedPlanType);
var subscriptionItem = subscription.Items.First(item => item.Price.Id == priceId);
if (providerPlan.PurchasedSeats == 0)
{
if (providerPlan.AllocatedSeats > newPlanConfiguration.SeatsMinimum)
if (providerPlan.AllocatedSeats > updatedSeatMinimum)
{
providerPlan.PurchasedSeats = providerPlan.AllocatedSeats - newPlanConfiguration.SeatsMinimum;
providerPlan.PurchasedSeats = providerPlan.AllocatedSeats - updatedSeatMinimum;
subscriptionItemOptionsList.Add(new SubscriptionItemOptions
{
@ -694,7 +689,7 @@ public class ProviderBillingService(
{
Id = subscriptionItem.Id,
Price = priceId,
Quantity = newPlanConfiguration.SeatsMinimum
Quantity = updatedSeatMinimum
});
}
}
@ -702,9 +697,9 @@ public class ProviderBillingService(
{
var totalSeats = providerPlan.SeatMinimum + providerPlan.PurchasedSeats;
if (newPlanConfiguration.SeatsMinimum <= totalSeats)
if (updatedSeatMinimum <= totalSeats)
{
providerPlan.PurchasedSeats = totalSeats - newPlanConfiguration.SeatsMinimum;
providerPlan.PurchasedSeats = totalSeats - updatedSeatMinimum;
}
else
{
@ -713,12 +708,12 @@ public class ProviderBillingService(
{
Id = subscriptionItem.Id,
Price = priceId,
Quantity = newPlanConfiguration.SeatsMinimum
Quantity = updatedSeatMinimum
});
}
}
providerPlan.SeatMinimum = newPlanConfiguration.SeatsMinimum;
providerPlan.SeatMinimum = updatedSeatMinimum;
await providerPlanRepository.ReplaceAsync(providerPlan);
}
@ -726,23 +721,33 @@ public class ProviderBillingService(
if (subscriptionItemOptionsList.Count > 0)
{
await stripeAdapter.SubscriptionUpdateAsync(command.GatewaySubscriptionId,
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
new SubscriptionUpdateOptions { Items = subscriptionItemOptionsList });
}
}
private Func<int, int, Task> CurrySeatScalingUpdate(
private Func<int, Task> CurrySeatScalingUpdate(
Provider provider,
ProviderPlan providerPlan,
int newlyAssignedSeats) => async (currentlySubscribedSeats, newlySubscribedSeats) =>
int newlyAssignedSeats) => async newlySubscribedSeats =>
{
var plan = await pricingClient.GetPlanOrThrow(providerPlan.PlanType);
var subscription = await subscriberService.GetSubscriptionOrThrow(provider);
await paymentService.AdjustSeats(
provider,
plan,
currentlySubscribedSeats,
newlySubscribedSeats);
var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType);
var item = subscription.Items.First(item => item.Price.Id == priceId);
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions
{
Items = [
new SubscriptionItemOptions
{
Id = item.Id,
Price = priceId,
Quantity = newlySubscribedSeats
}
]
});
var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum
? newlySubscribedSeats - providerPlan.SeatMinimum
@ -786,7 +791,7 @@ public class ProviderBillingService(
Provider provider,
Organization organization)
{
if (provider.Type == ProviderType.MultiOrganizationEnterprise)
if (provider.Type == ProviderType.BusinessUnit)
{
return (await providerPlanRepository.GetByProviderId(provider.Id)).First().PlanType;
}

View File

@ -0,0 +1,133 @@
// ReSharper disable SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault
#nullable enable
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing;
using Bit.Core.Billing.Enums;
using Stripe;
namespace Bit.Commercial.Core.Billing;
public static class ProviderPriceAdapter
{
public static class MSP
{
public static class Active
{
public const string Enterprise = "provider-portal-enterprise-monthly-2025";
public const string Teams = "provider-portal-teams-monthly-2025";
}
public static class Legacy
{
public const string Enterprise = "password-manager-provider-portal-enterprise-monthly-2024";
public const string Teams = "password-manager-provider-portal-teams-monthly-2024";
public static readonly List<string> List = [Enterprise, Teams];
}
}
public static class BusinessUnit
{
public static class Active
{
public const string Annually = "business-unit-portal-enterprise-annually-2025";
public const string Monthly = "business-unit-portal-enterprise-monthly-2025";
}
public static class Legacy
{
public const string Annually = "password-manager-provider-portal-enterprise-annually-2024";
public const string Monthly = "password-manager-provider-portal-enterprise-monthly-2024";
public static readonly List<string> List = [Annually, Monthly];
}
}
/// <summary>
/// Uses the <paramref name="provider"/>'s <see cref="Provider.Type"/> and <paramref name="subscription"/> to determine
/// whether the <paramref name="provider"/> is on active or legacy pricing and then returns a Stripe price ID for the provided
/// <paramref name="planType"/> based on that determination.
/// </summary>
/// <param name="provider">The provider to get the Stripe price ID for.</param>
/// <param name="subscription">The provider's subscription.</param>
/// <param name="planType">The plan type correlating to the desired Stripe price ID.</param>
/// <returns>A Stripe <see cref="Stripe.Price"/> ID.</returns>
/// <exception cref="BillingException">Thrown when the provider's type is not <see cref="ProviderType.Msp"/> or <see cref="ProviderType.BusinessUnit"/>.</exception>
/// <exception cref="BillingException">Thrown when the provided <see cref="planType"/> does not relate to a Stripe price ID.</exception>
public static string GetPriceId(
Provider provider,
Subscription subscription,
PlanType planType)
{
var priceIds = subscription.Items.Select(item => item.Price.Id);
var invalidPlanType =
new BillingException(message: $"PlanType {planType} does not have an associated provider price in Stripe");
return provider.Type switch
{
ProviderType.Msp => MSP.Legacy.List.Intersect(priceIds).Any()
? planType switch
{
PlanType.TeamsMonthly => MSP.Legacy.Teams,
PlanType.EnterpriseMonthly => MSP.Legacy.Enterprise,
_ => throw invalidPlanType
}
: planType switch
{
PlanType.TeamsMonthly => MSP.Active.Teams,
PlanType.EnterpriseMonthly => MSP.Active.Enterprise,
_ => throw invalidPlanType
},
ProviderType.BusinessUnit => BusinessUnit.Legacy.List.Intersect(priceIds).Any()
? planType switch
{
PlanType.EnterpriseAnnually => BusinessUnit.Legacy.Annually,
PlanType.EnterpriseMonthly => BusinessUnit.Legacy.Monthly,
_ => throw invalidPlanType
}
: planType switch
{
PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually,
PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly,
_ => throw invalidPlanType
},
_ => throw new BillingException(
$"ProviderType {provider.Type} does not have any associated provider price IDs")
};
}
/// <summary>
/// Uses the <paramref name="provider"/>'s <see cref="Provider.Type"/> to return the active Stripe price ID for the provided
/// <paramref name="planType"/>.
/// </summary>
/// <param name="provider">The provider to get the Stripe price ID for.</param>
/// <param name="planType">The plan type correlating to the desired Stripe price ID.</param>
/// <returns>A Stripe <see cref="Stripe.Price"/> ID.</returns>
/// <exception cref="BillingException">Thrown when the provider's type is not <see cref="ProviderType.Msp"/> or <see cref="ProviderType.BusinessUnit"/>.</exception>
/// <exception cref="BillingException">Thrown when the provided <see cref="planType"/> does not relate to a Stripe price ID.</exception>
public static string GetActivePriceId(
Provider provider,
PlanType planType)
{
var invalidPlanType =
new BillingException(message: $"PlanType {planType} does not have an associated provider price in Stripe");
return provider.Type switch
{
ProviderType.Msp => planType switch
{
PlanType.TeamsMonthly => MSP.Active.Teams,
PlanType.EnterpriseMonthly => MSP.Active.Enterprise,
_ => throw invalidPlanType
},
ProviderType.BusinessUnit => planType switch
{
PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually,
PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly,
_ => throw invalidPlanType
},
_ => throw new BillingException(
$"ProviderType {provider.Type} does not have any associated provider price IDs")
};
}
}

View File

@ -1,9 +1,9 @@
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.SecretsManager.Queries.Projects.Interfaces;
using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Utilities;
namespace Bit.Commercial.Core.SecretsManager.Queries.Projects;
@ -11,13 +11,16 @@ public class MaxProjectsQuery : IMaxProjectsQuery
{
private readonly IOrganizationRepository _organizationRepository;
private readonly IProjectRepository _projectRepository;
private readonly IPricingClient _pricingClient;
public MaxProjectsQuery(
IOrganizationRepository organizationRepository,
IProjectRepository projectRepository)
IProjectRepository projectRepository,
IPricingClient pricingClient)
{
_organizationRepository = organizationRepository;
_projectRepository = projectRepository;
_pricingClient = pricingClient;
}
public async Task<(short? max, bool? overMax)> GetByOrgIdAsync(Guid organizationId, int projectsToAdd)
@ -28,8 +31,7 @@ public class MaxProjectsQuery : IMaxProjectsQuery
throw new NotFoundException();
}
// TODO: PRICING -> https://bitwarden.atlassian.net/browse/PM-17122
var plan = StaticStore.GetPlan(org.PlanType);
var plan = await _pricingClient.GetPlan(org.PlanType);
if (plan?.SecretsManager == null)
{
throw new BadRequestException("Existing plan not found.");

View File

@ -16,5 +16,6 @@ public static class ServiceCollectionExtensions
services.AddScoped<ICreateProviderCommand, CreateProviderCommand>();
services.AddScoped<IRemoveOrganizationFromProviderCommand, RemoveOrganizationFromProviderCommand>();
services.AddTransient<IProviderBillingService, ProviderBillingService>();
services.AddTransient<IBusinessUnitConverter, BusinessUnitConverter>();
}
}

View File

@ -1,10 +1,8 @@
using Bit.Core;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Scim.Groups.Interfaces;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
@ -24,10 +22,8 @@ public class GroupsController : Controller
private readonly IGetGroupsListQuery _getGroupsListQuery;
private readonly IDeleteGroupCommand _deleteGroupCommand;
private readonly IPatchGroupCommand _patchGroupCommand;
private readonly IPatchGroupCommandvNext _patchGroupCommandvNext;
private readonly IPostGroupCommand _postGroupCommand;
private readonly IPutGroupCommand _putGroupCommand;
private readonly IFeatureService _featureService;
public GroupsController(
IGroupRepository groupRepository,
@ -35,10 +31,8 @@ public class GroupsController : Controller
IGetGroupsListQuery getGroupsListQuery,
IDeleteGroupCommand deleteGroupCommand,
IPatchGroupCommand patchGroupCommand,
IPatchGroupCommandvNext patchGroupCommandvNext,
IPostGroupCommand postGroupCommand,
IPutGroupCommand putGroupCommand,
IFeatureService featureService
IPutGroupCommand putGroupCommand
)
{
_groupRepository = groupRepository;
@ -46,10 +40,8 @@ public class GroupsController : Controller
_getGroupsListQuery = getGroupsListQuery;
_deleteGroupCommand = deleteGroupCommand;
_patchGroupCommand = patchGroupCommand;
_patchGroupCommandvNext = patchGroupCommandvNext;
_postGroupCommand = postGroupCommand;
_putGroupCommand = putGroupCommand;
_featureService = featureService;
}
[HttpGet("{id}")]
@ -103,21 +95,13 @@ public class GroupsController : Controller
[HttpPatch("{id}")]
public async Task<IActionResult> Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model)
{
if (_featureService.IsEnabled(FeatureFlagKeys.ShortcutDuplicatePatchRequests))
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organizationId)
{
throw new NotFoundException("Group not found.");
}
await _patchGroupCommandvNext.PatchGroupAsync(group, model);
return new NoContentResult();
throw new NotFoundException("Group not found.");
}
var organization = await _organizationRepository.GetByIdAsync(organizationId);
await _patchGroupCommand.PatchGroupAsync(organization, id, model);
await _patchGroupCommand.PatchGroupAsync(group, model);
return new NoContentResult();
}

View File

@ -1,4 +1,5 @@
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
@ -23,7 +24,7 @@ public class UsersController : Controller
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
private readonly IPatchUserCommand _patchUserCommand;
private readonly IPostUserCommand _postUserCommand;
private readonly ILogger<UsersController> _logger;
private readonly IRestoreOrganizationUserCommand _restoreOrganizationUserCommand;
public UsersController(
IOrganizationUserRepository organizationUserRepository,
@ -32,7 +33,7 @@ public class UsersController : Controller
IRemoveOrganizationUserCommand removeOrganizationUserCommand,
IPatchUserCommand patchUserCommand,
IPostUserCommand postUserCommand,
ILogger<UsersController> logger)
IRestoreOrganizationUserCommand restoreOrganizationUserCommand)
{
_organizationUserRepository = organizationUserRepository;
_organizationService = organizationService;
@ -40,7 +41,7 @@ public class UsersController : Controller
_removeOrganizationUserCommand = removeOrganizationUserCommand;
_patchUserCommand = patchUserCommand;
_postUserCommand = postUserCommand;
_logger = logger;
_restoreOrganizationUserCommand = restoreOrganizationUserCommand;
}
[HttpGet("{id}")]
@ -93,7 +94,7 @@ public class UsersController : Controller
if (model.Active && orgUser.Status == OrganizationUserStatusType.Revoked)
{
await _organizationService.RestoreUserAsync(orgUser, EventSystemUser.SCIM);
await _restoreOrganizationUserCommand.RestoreUserAsync(orgUser, EventSystemUser.SCIM);
}
else if (!model.Active && orgUser.Status != OrganizationUserStatusType.Revoked)
{

View File

@ -5,5 +5,5 @@ namespace Bit.Scim.Groups.Interfaces;
public interface IPatchGroupCommand
{
Task PatchGroupAsync(Organization organization, Guid id, ScimPatchModel model);
Task PatchGroupAsync(Group group, ScimPatchModel model);
}

View File

@ -1,9 +0,0 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Scim.Models;
namespace Bit.Scim.Groups.Interfaces;
public interface IPatchGroupCommandvNext
{
Task PatchGroupAsync(Group group, ScimPatchModel model);
}

View File

@ -5,8 +5,10 @@ using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Scim.Groups.Interfaces;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
namespace Bit.Scim.Groups;
@ -16,118 +18,137 @@ public class PatchGroupCommand : IPatchGroupCommand
private readonly IGroupService _groupService;
private readonly IUpdateGroupCommand _updateGroupCommand;
private readonly ILogger<PatchGroupCommand> _logger;
private readonly IOrganizationRepository _organizationRepository;
public PatchGroupCommand(
IGroupRepository groupRepository,
IGroupService groupService,
IUpdateGroupCommand updateGroupCommand,
ILogger<PatchGroupCommand> logger)
ILogger<PatchGroupCommand> logger,
IOrganizationRepository organizationRepository)
{
_groupRepository = groupRepository;
_groupService = groupService;
_updateGroupCommand = updateGroupCommand;
_logger = logger;
_organizationRepository = organizationRepository;
}
public async Task PatchGroupAsync(Organization organization, Guid id, ScimPatchModel model)
public async Task PatchGroupAsync(Group group, ScimPatchModel model)
{
var group = await _groupRepository.GetByIdAsync(id);
if (group == null || group.OrganizationId != organization.Id)
{
throw new NotFoundException("Group not found.");
}
var operationHandled = false;
foreach (var operation in model.Operations)
{
// Replace operations
if (operation.Op?.ToLowerInvariant() == "replace")
{
// Replace a list of members
if (operation.Path?.ToLowerInvariant() == "members")
await HandleOperationAsync(group, operation);
}
}
private async Task HandleOperationAsync(Group group, ScimPatchModel.OperationModel operation)
{
switch (operation.Op?.ToLowerInvariant())
{
// Replace a list of members
case PatchOps.Replace when operation.Path?.ToLowerInvariant() == PatchPaths.Members:
{
var ids = GetOperationValueIds(operation.Value);
await _groupRepository.UpdateUsersAsync(group.Id, ids);
operationHandled = true;
break;
}
// Replace group name from path
else if (operation.Path?.ToLowerInvariant() == "displayname")
// Replace group name from path
case PatchOps.Replace when operation.Path?.ToLowerInvariant() == PatchPaths.DisplayName:
{
group.Name = operation.Value.GetString();
var organization = await _organizationRepository.GetByIdAsync(group.OrganizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _updateGroupCommand.UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
operationHandled = true;
break;
}
// Replace group name from value object
else if (string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("displayName", out var displayNameProperty))
// Replace group name from value object
case PatchOps.Replace when
string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("displayName", out var displayNameProperty):
{
group.Name = displayNameProperty.GetString();
var organization = await _organizationRepository.GetByIdAsync(group.OrganizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _updateGroupCommand.UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
operationHandled = true;
break;
}
}
// Add a single member
else if (operation.Op?.ToLowerInvariant() == "add" &&
case PatchOps.Add when
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.ToLowerInvariant().StartsWith("members[value eq "))
{
var addId = GetOperationPathId(operation.Path);
if (addId.HasValue)
operation.Path.StartsWith("members[value eq ", StringComparison.OrdinalIgnoreCase) &&
TryGetOperationPathId(operation.Path, out var addId):
{
await AddMembersAsync(group, [addId]);
break;
}
// Add a list of members
case PatchOps.Add when
operation.Path?.ToLowerInvariant() == PatchPaths.Members:
{
await AddMembersAsync(group, GetOperationValueIds(operation.Value));
break;
}
// Remove a single member
case PatchOps.Remove when
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.StartsWith("members[value eq ", StringComparison.OrdinalIgnoreCase) &&
TryGetOperationPathId(operation.Path, out var removeId):
{
await _groupService.DeleteUserAsync(group, removeId, EventSystemUser.SCIM);
break;
}
// Remove a list of members
case PatchOps.Remove when
operation.Path?.ToLowerInvariant() == PatchPaths.Members:
{
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
orgUserIds.Add(addId.Value);
foreach (var v in GetOperationValueIds(operation.Value))
{
orgUserIds.Remove(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
break;
}
}
// Add a list of members
else if (operation.Op?.ToLowerInvariant() == "add" &&
operation.Path?.ToLowerInvariant() == "members")
{
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
foreach (var v in GetOperationValueIds(operation.Value))
{
orgUserIds.Add(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
}
// Remove a single member
else if (operation.Op?.ToLowerInvariant() == "remove" &&
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.ToLowerInvariant().StartsWith("members[value eq "))
{
var removeId = GetOperationPathId(operation.Path);
if (removeId.HasValue)
{
await _groupService.DeleteUserAsync(group, removeId.Value, EventSystemUser.SCIM);
operationHandled = true;
}
}
// Remove a list of members
else if (operation.Op?.ToLowerInvariant() == "remove" &&
operation.Path?.ToLowerInvariant() == "members")
{
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
foreach (var v in GetOperationValueIds(operation.Value))
{
orgUserIds.Remove(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
operationHandled = true;
}
}
if (!operationHandled)
{
_logger.LogWarning("Group patch operation not handled: {0} : ",
string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}")));
default:
{
_logger.LogWarning("Group patch operation not handled: {OperationOp}:{OperationPath}", operation.Op, operation.Path);
break;
}
}
}
private List<Guid> GetOperationValueIds(JsonElement objArray)
private async Task AddMembersAsync(Group group, HashSet<Guid> usersToAdd)
{
var ids = new List<Guid>();
// Azure Entra ID is known to send redundant "add" requests for each existing member every time any member
// is removed. To avoid excessive load on the database, we check against the high availability replica and
// return early if they already exist.
var groupMembers = await _groupRepository.GetManyUserIdsByIdAsync(group.Id, useReadOnlyReplica: true);
if (usersToAdd.IsSubsetOf(groupMembers))
{
_logger.LogDebug("Ignoring duplicate SCIM request to add members {Members} to group {Group}", usersToAdd, group.Id);
return;
}
await _groupRepository.AddGroupUsersByIdAsync(group.Id, usersToAdd);
}
private static HashSet<Guid> GetOperationValueIds(JsonElement objArray)
{
var ids = new HashSet<Guid>();
foreach (var obj in objArray.EnumerateArray())
{
if (obj.TryGetProperty("value", out var valueProperty))
@ -141,13 +162,9 @@ public class PatchGroupCommand : IPatchGroupCommand
return ids;
}
private Guid? GetOperationPathId(string path)
private static bool TryGetOperationPathId(string path, out Guid pathId)
{
// Parse Guid from string like: members[value eq "{GUID}"}]
if (Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out var id))
{
return id;
}
return null;
return Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out pathId);
}
}

View File

@ -1,170 +0,0 @@
using System.Text.Json;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Scim.Groups.Interfaces;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
namespace Bit.Scim.Groups;
public class PatchGroupCommandvNext : IPatchGroupCommandvNext
{
private readonly IGroupRepository _groupRepository;
private readonly IGroupService _groupService;
private readonly IUpdateGroupCommand _updateGroupCommand;
private readonly ILogger<PatchGroupCommandvNext> _logger;
private readonly IOrganizationRepository _organizationRepository;
public PatchGroupCommandvNext(
IGroupRepository groupRepository,
IGroupService groupService,
IUpdateGroupCommand updateGroupCommand,
ILogger<PatchGroupCommandvNext> logger,
IOrganizationRepository organizationRepository)
{
_groupRepository = groupRepository;
_groupService = groupService;
_updateGroupCommand = updateGroupCommand;
_logger = logger;
_organizationRepository = organizationRepository;
}
public async Task PatchGroupAsync(Group group, ScimPatchModel model)
{
foreach (var operation in model.Operations)
{
await HandleOperationAsync(group, operation);
}
}
private async Task HandleOperationAsync(Group group, ScimPatchModel.OperationModel operation)
{
switch (operation.Op?.ToLowerInvariant())
{
// Replace a list of members
case PatchOps.Replace when operation.Path?.ToLowerInvariant() == PatchPaths.Members:
{
var ids = GetOperationValueIds(operation.Value);
await _groupRepository.UpdateUsersAsync(group.Id, ids);
break;
}
// Replace group name from path
case PatchOps.Replace when operation.Path?.ToLowerInvariant() == PatchPaths.DisplayName:
{
group.Name = operation.Value.GetString();
var organization = await _organizationRepository.GetByIdAsync(group.OrganizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _updateGroupCommand.UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
break;
}
// Replace group name from value object
case PatchOps.Replace when
string.IsNullOrWhiteSpace(operation.Path) &&
operation.Value.TryGetProperty("displayName", out var displayNameProperty):
{
group.Name = displayNameProperty.GetString();
var organization = await _organizationRepository.GetByIdAsync(group.OrganizationId);
if (organization == null)
{
throw new NotFoundException();
}
await _updateGroupCommand.UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
break;
}
// Add a single member
case PatchOps.Add when
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.StartsWith("members[value eq ", StringComparison.OrdinalIgnoreCase) &&
TryGetOperationPathId(operation.Path, out var addId):
{
await AddMembersAsync(group, [addId]);
break;
}
// Add a list of members
case PatchOps.Add when
operation.Path?.ToLowerInvariant() == PatchPaths.Members:
{
await AddMembersAsync(group, GetOperationValueIds(operation.Value));
break;
}
// Remove a single member
case PatchOps.Remove when
!string.IsNullOrWhiteSpace(operation.Path) &&
operation.Path.StartsWith("members[value eq ", StringComparison.OrdinalIgnoreCase) &&
TryGetOperationPathId(operation.Path, out var removeId):
{
await _groupService.DeleteUserAsync(group, removeId, EventSystemUser.SCIM);
break;
}
// Remove a list of members
case PatchOps.Remove when
operation.Path?.ToLowerInvariant() == PatchPaths.Members:
{
var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet();
foreach (var v in GetOperationValueIds(operation.Value))
{
orgUserIds.Remove(v);
}
await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds);
break;
}
default:
{
_logger.LogWarning("Group patch operation not handled: {OperationOp}:{OperationPath}", operation.Op, operation.Path);
break;
}
}
}
private async Task AddMembersAsync(Group group, HashSet<Guid> usersToAdd)
{
// Azure Entra ID is known to send redundant "add" requests for each existing member every time any member
// is removed. To avoid excessive load on the database, we check against the high availability replica and
// return early if they already exist.
var groupMembers = await _groupRepository.GetManyUserIdsByIdAsync(group.Id, useReadOnlyReplica: true);
if (usersToAdd.IsSubsetOf(groupMembers))
{
_logger.LogDebug("Ignoring duplicate SCIM request to add members {Members} to group {Group}", usersToAdd, group.Id);
return;
}
await _groupRepository.AddGroupUsersByIdAsync(group.Id, usersToAdd);
}
private static HashSet<Guid> GetOperationValueIds(JsonElement objArray)
{
var ids = new HashSet<Guid>();
foreach (var obj in objArray.EnumerateArray())
{
if (obj.TryGetProperty("value", out var valueProperty))
{
if (valueProperty.TryGetGuid(out var guid))
{
ids.Add(guid);
}
}
}
return ids;
}
private static bool TryGetOperationPathId(string path, out Guid pathId)
{
// Parse Guid from string like: members[value eq "{GUID}"}]
return Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out pathId);
}
}

View File

@ -1,8 +1,11 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.Enums;
using Bit.Core.Models.Business;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data;
using Bit.Core.Utilities;
using OrganizationUserInvite = Bit.Core.Models.Business.OrganizationUserInvite;
namespace Bit.Scim.Models;
@ -10,7 +13,8 @@ public class ScimUserRequestModel : BaseScimUserModel
{
public ScimUserRequestModel()
: base(false)
{ }
{
}
public OrganizationUserInvite ToOrganizationUserInvite(ScimProviderType scimProvider)
{
@ -25,6 +29,31 @@ public class ScimUserRequestModel : BaseScimUserModel
};
}
public InviteOrganizationUsersRequest ToRequest(
ScimProviderType scimProvider,
InviteOrganization inviteOrganization,
DateTimeOffset performedAt)
{
var email = EmailForInvite(scimProvider);
if (string.IsNullOrWhiteSpace(email) || !Active)
{
throw new BadRequestException();
}
return new InviteOrganizationUsersRequest(
invites:
[
new Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models.OrganizationUserInvite(
email: email,
externalId: ExternalIdForInvite()
)
],
inviteOrganization: inviteOrganization,
performedBy: Guid.Empty, // SCIM does not have a user id
performedAt: performedAt);
}
private string EmailForInvite(ScimProviderType scimProvider)
{
var email = PrimaryEmail?.ToLowerInvariant();

View File

@ -1,4 +1,5 @@
using Bit.Core.Enums;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
@ -11,15 +12,18 @@ public class PatchUserCommand : IPatchUserCommand
{
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IOrganizationService _organizationService;
private readonly IRestoreOrganizationUserCommand _restoreOrganizationUserCommand;
private readonly ILogger<PatchUserCommand> _logger;
public PatchUserCommand(
IOrganizationUserRepository organizationUserRepository,
IOrganizationService organizationService,
IRestoreOrganizationUserCommand restoreOrganizationUserCommand,
ILogger<PatchUserCommand> logger)
{
_organizationUserRepository = organizationUserRepository;
_organizationService = organizationService;
_restoreOrganizationUserCommand = restoreOrganizationUserCommand;
_logger = logger;
}
@ -71,7 +75,7 @@ public class PatchUserCommand : IPatchUserCommand
{
if (active && orgUser.Status == OrganizationUserStatusType.Revoked)
{
await _organizationService.RestoreUserAsync(orgUser, EventSystemUser.SCIM);
await _restoreOrganizationUserCommand.RestoreUserAsync(orgUser, EventSystemUser.SCIM);
return true;
}
else if (!active && orgUser.Status != OrganizationUserStatusType.Revoked)

View File

@ -1,39 +1,99 @@
using Bit.Core.Enums;
#nullable enable
using Bit.Core;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Errors;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models;
using Bit.Core.Billing.Pricing;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Commands;
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Scim.Context;
using Bit.Scim.Models;
using Bit.Scim.Users.Interfaces;
using static Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Errors.ErrorMapper;
namespace Bit.Scim.Users;
public class PostUserCommand : IPostUserCommand
public class PostUserCommand(
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationService organizationService,
IPaymentService paymentService,
IScimContext scimContext,
IFeatureService featureService,
IInviteOrganizationUsersCommand inviteOrganizationUsersCommand,
TimeProvider timeProvider,
IPricingClient pricingClient)
: IPostUserCommand
{
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IOrganizationService _organizationService;
private readonly IPaymentService _paymentService;
private readonly IScimContext _scimContext;
public PostUserCommand(
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IOrganizationService organizationService,
IPaymentService paymentService,
IScimContext scimContext)
public async Task<OrganizationUserUserDetails?> PostUserAsync(Guid organizationId, ScimUserRequestModel model)
{
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_organizationService = organizationService;
_paymentService = paymentService;
_scimContext = scimContext;
if (featureService.IsEnabled(FeatureFlagKeys.ScimInviteUserOptimization) is false)
{
return await InviteScimOrganizationUserAsync(model, organizationId, scimContext.RequestScimProvider);
}
return await InviteScimOrganizationUserAsync_vNext(model, organizationId, scimContext.RequestScimProvider);
}
public async Task<OrganizationUserUserDetails> PostUserAsync(Guid organizationId, ScimUserRequestModel model)
private async Task<OrganizationUserUserDetails?> InviteScimOrganizationUserAsync_vNext(
ScimUserRequestModel model,
Guid organizationId,
ScimProviderType scimProvider)
{
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization is null)
{
throw new NotFoundException();
}
var plan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var request = model.ToRequest(
scimProvider: scimProvider,
inviteOrganization: new InviteOrganization(organization, plan),
performedAt: timeProvider.GetUtcNow());
var orgUsers = await organizationUserRepository
.GetManyDetailsByOrganizationAsync(request.InviteOrganization.OrganizationId);
if (orgUsers.Any(existingUser =>
request.Invites.First().Email.Equals(existingUser.Email, StringComparison.OrdinalIgnoreCase) ||
request.Invites.First().ExternalId.Equals(existingUser.ExternalId, StringComparison.OrdinalIgnoreCase)))
{
throw new ConflictException("User already exists.");
}
var result = await inviteOrganizationUsersCommand.InviteScimOrganizationUserAsync(request);
var invitedOrganizationUserId = result switch
{
Success<ScimInviteOrganizationUsersResponse> success => success.Value.InvitedUser.Id,
Failure<ScimInviteOrganizationUsersResponse> failure when failure.Errors
.Any(x => x.Message == NoUsersToInviteError.Code) => (Guid?)null,
Failure<ScimInviteOrganizationUsersResponse> failure when failure.Errors.Length != 0 => throw MapToBitException(failure.Errors),
_ => throw new InvalidOperationException()
};
var organizationUser = invitedOrganizationUserId.HasValue
? await organizationUserRepository.GetDetailsByIdAsync(invitedOrganizationUserId.Value)
: null;
return organizationUser;
}
private async Task<OrganizationUserUserDetails?> InviteScimOrganizationUserAsync(
ScimUserRequestModel model,
Guid organizationId,
ScimProviderType scimProvider)
{
var scimProvider = _scimContext.RequestScimProvider;
var invite = model.ToOrganizationUserInvite(scimProvider);
var email = invite.Emails.Single();
@ -44,7 +104,7 @@ public class PostUserCommand : IPostUserCommand
throw new BadRequestException();
}
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var orgUsers = await organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
var orgUserByEmail = orgUsers.FirstOrDefault(ou => ou.Email?.ToLowerInvariant() == email);
if (orgUserByEmail != null)
{
@ -57,13 +117,21 @@ public class PostUserCommand : IPostUserCommand
throw new ConflictException();
}
var organization = await _organizationRepository.GetByIdAsync(organizationId);
var hasStandaloneSecretsManager = await _paymentService.HasSecretsManagerStandalone(organization);
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
var hasStandaloneSecretsManager = await paymentService.HasSecretsManagerStandalone(organization);
invite.AccessSecretsManager = hasStandaloneSecretsManager;
var invitedOrgUser = await _organizationService.InviteUserAsync(organizationId, invitingUserId: null, EventSystemUser.SCIM,
invite, externalId);
var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id);
var invitedOrgUser = await organizationService.InviteUserAsync(organizationId, invitingUserId: null,
EventSystemUser.SCIM,
invite,
externalId);
var orgUser = await organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id);
return orgUser;
}

View File

@ -10,7 +10,6 @@ public static class ScimServiceCollectionExtensions
public static void AddScimGroupCommands(this IServiceCollection services)
{
services.AddScoped<IPatchGroupCommand, PatchGroupCommand>();
services.AddScoped<IPatchGroupCommandvNext, PatchGroupCommandvNext>();
services.AddScoped<IPostGroupCommand, PostGroupCommand>();
services.AddScoped<IPutGroupCommand, PutGroupCommand>();
}

View File

@ -63,7 +63,7 @@ public class CreateProviderCommandTests
}
[Theory, BitAutoData]
public async Task CreateMultiOrganizationEnterpriseAsync_Success(
public async Task CreateBusinessUnitAsync_Success(
Provider provider,
User user,
PlanType plan,
@ -71,13 +71,13 @@ public class CreateProviderCommandTests
SutProvider<CreateProviderCommand> sutProvider)
{
// Arrange
provider.Type = ProviderType.MultiOrganizationEnterprise;
provider.Type = ProviderType.BusinessUnit;
var userRepository = sutProvider.GetDependency<IUserRepository>();
userRepository.GetByEmailAsync(user.Email).Returns(user);
// Act
await sutProvider.Sut.CreateMultiOrganizationEnterpriseAsync(provider, user.Email, plan, minimumSeats);
await sutProvider.Sut.CreateBusinessUnitAsync(provider, user.Email, plan, minimumSeats);
// Assert
await sutProvider.GetDependency<IProviderRepository>().ReceivedWithAnyArgs().CreateAsync(provider);
@ -85,7 +85,7 @@ public class CreateProviderCommandTests
}
[Theory, BitAutoData]
public async Task CreateMultiOrganizationEnterpriseAsync_UserIdIsInvalid_Throws(
public async Task CreateBusinessUnitAsync_UserIdIsInvalid_Throws(
Provider provider,
SutProvider<CreateProviderCommand> sutProvider)
{
@ -94,7 +94,7 @@ public class CreateProviderCommandTests
// Act
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CreateMultiOrganizationEnterpriseAsync(provider, default, default, default));
() => sutProvider.Sut.CreateBusinessUnitAsync(provider, default, default, default));
// Assert
Assert.Contains("Invalid owner.", exception.Message);

View File

@ -228,6 +228,26 @@ public class RemoveOrganizationFromProviderCommandTests
Id = "subscription_id"
});
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == organization.GatewayCustomerId &&
options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice &&
options.DaysUntilDue == 30 &&
options.Metadata["organizationId"] == organization.Id.ToString() &&
options.OffSession == true &&
options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations &&
options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId &&
options.Items.First().Quantity == organization.Seats)
, Arg.Any<Customer>()))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization);
await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(options =>

View File

@ -0,0 +1,501 @@
#nullable enable
using System.Text;
using Bit.Commercial.Core.Billing;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
namespace Bit.Commercial.Core.Test.Billing;
public class BusinessUnitConverterTests
{
private readonly IDataProtectionProvider _dataProtectionProvider = Substitute.For<IDataProtectionProvider>();
private readonly GlobalSettings _globalSettings = new();
private readonly ILogger<BusinessUnitConverter> _logger = Substitute.For<ILogger<BusinessUnitConverter>>();
private readonly IMailService _mailService = Substitute.For<IMailService>();
private readonly IOrganizationRepository _organizationRepository = Substitute.For<IOrganizationRepository>();
private readonly IOrganizationUserRepository _organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IProviderOrganizationRepository _providerOrganizationRepository = Substitute.For<IProviderOrganizationRepository>();
private readonly IProviderPlanRepository _providerPlanRepository = Substitute.For<IProviderPlanRepository>();
private readonly IProviderRepository _providerRepository = Substitute.For<IProviderRepository>();
private readonly IProviderUserRepository _providerUserRepository = Substitute.For<IProviderUserRepository>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly ISubscriberService _subscriberService = Substitute.For<ISubscriberService>();
private readonly IUserRepository _userRepository = Substitute.For<IUserRepository>();
private BusinessUnitConverter BuildConverter() => new(
_dataProtectionProvider,
_globalSettings,
_logger,
_mailService,
_organizationRepository,
_organizationUserRepository,
_pricingClient,
_providerOrganizationRepository,
_providerPlanRepository,
_providerRepository,
_providerUserRepository,
_stripeAdapter,
_subscriberService,
_userRepository);
#region FinalizeConversion
[Theory, BitAutoData]
public async Task FinalizeConversion_Succeeds_ReturnsProviderId(
Organization organization,
Guid userId,
string providerKey,
string organizationKey)
{
organization.PlanType = PlanType.EnterpriseAnnually2020;
var enterpriseAnnually2020 = StaticStore.GetPlan(PlanType.EnterpriseAnnually2020);
var subscription = new Subscription
{
Id = "subscription_id",
CustomerId = "customer_id",
Status = StripeConstants.SubscriptionStatus.Active,
Items = new StripeList<SubscriptionItem>
{
Data = [
new SubscriptionItem
{
Id = "subscription_item_id",
Price = new Price
{
Id = enterpriseAnnually2020.PasswordManager.StripeSeatPlanId
}
}
]
}
};
_subscriberService.GetSubscription(organization).Returns(subscription);
var user = new User
{
Id = Guid.NewGuid(),
Email = "provider-admin@example.com"
};
_userRepository.GetByIdAsync(userId).Returns(user);
var token = SetupDataProtection(organization, user.Email);
var organizationUser = new OrganizationUser { Status = OrganizationUserStatusType.Confirmed };
_organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id)
.Returns(organizationUser);
var provider = new Provider
{
Type = ProviderType.BusinessUnit,
Status = ProviderStatusType.Pending
};
_providerRepository.GetByOrganizationIdAsync(organization.Id).Returns(provider);
var providerUser = new ProviderUser
{
Type = ProviderUserType.ProviderAdmin,
Status = ProviderUserStatusType.Invited
};
_providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser);
var providerOrganization = new ProviderOrganization();
_providerOrganizationRepository.GetByOrganizationId(organization.Id).Returns(providerOrganization);
_pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually2020)
.Returns(enterpriseAnnually2020);
var enterpriseAnnually = StaticStore.GetPlan(PlanType.EnterpriseAnnually);
_pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually)
.Returns(enterpriseAnnually);
var businessUnitConverter = BuildConverter();
await businessUnitConverter.FinalizeConversion(organization, userId, token, providerKey, organizationKey);
await _stripeAdapter.Received(2).CustomerUpdateAsync(subscription.CustomerId, Arg.Any<CustomerUpdateOptions>());
var updatedPriceId = ProviderPriceAdapter.GetActivePriceId(provider, enterpriseAnnually.Type);
await _stripeAdapter.Received(1).SubscriptionUpdateAsync(subscription.Id, Arg.Is<SubscriptionUpdateOptions>(
arguments =>
arguments.Items.Count == 2 &&
arguments.Items[0].Id == "subscription_item_id" &&
arguments.Items[0].Deleted == true &&
arguments.Items[1].Price == updatedPriceId &&
arguments.Items[1].Quantity == organization.Seats));
await _organizationRepository.Received(1).ReplaceAsync(Arg.Is<Organization>(arguments =>
arguments.PlanType == PlanType.EnterpriseAnnually &&
arguments.Status == OrganizationStatusType.Managed &&
arguments.GatewayCustomerId == null &&
arguments.GatewaySubscriptionId == null));
await _providerOrganizationRepository.Received(1).ReplaceAsync(Arg.Is<ProviderOrganization>(arguments =>
arguments.Key == organizationKey));
await _providerRepository.Received(1).ReplaceAsync(Arg.Is<Provider>(arguments =>
arguments.Gateway == GatewayType.Stripe &&
arguments.GatewayCustomerId == subscription.CustomerId &&
arguments.GatewaySubscriptionId == subscription.Id &&
arguments.Status == ProviderStatusType.Billable));
await _providerUserRepository.Received(1).ReplaceAsync(Arg.Is<ProviderUser>(arguments =>
arguments.Key == providerKey &&
arguments.Status == ProviderUserStatusType.Confirmed));
}
/*
* Because the validation for finalization is not an applicative like initialization is,
* I'm just testing one specific failure here. I don't see much value in testing every single opportunity for failure.
*/
[Theory, BitAutoData]
public async Task FinalizeConversion_ValidationFails_ThrowsBillingException(
Organization organization,
Guid userId,
string token,
string providerKey,
string organizationKey)
{
organization.PlanType = PlanType.EnterpriseAnnually2020;
var subscription = new Subscription
{
Status = StripeConstants.SubscriptionStatus.Canceled
};
_subscriberService.GetSubscription(organization).Returns(subscription);
var businessUnitConverter = BuildConverter();
await Assert.ThrowsAsync<BillingException>(() =>
businessUnitConverter.FinalizeConversion(organization, userId, token, providerKey, organizationKey));
await _organizationUserRepository.DidNotReceiveWithAnyArgs()
.GetByOrganizationAsync(Arg.Any<Guid>(), Arg.Any<Guid>());
}
#endregion
#region InitiateConversion
[Theory, BitAutoData]
public async Task InitiateConversion_Succeeds_ReturnsProviderId(
Organization organization,
string providerAdminEmail)
{
organization.PlanType = PlanType.EnterpriseAnnually;
_subscriberService.GetSubscription(organization).Returns(new Subscription
{
Status = StripeConstants.SubscriptionStatus.Active
});
var user = new User
{
Id = Guid.NewGuid(),
Email = providerAdminEmail
};
_userRepository.GetByEmailAsync(providerAdminEmail).Returns(user);
var organizationUser = new OrganizationUser { Status = OrganizationUserStatusType.Confirmed };
_organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id)
.Returns(organizationUser);
var provider = new Provider { Id = Guid.NewGuid() };
_providerRepository.CreateAsync(Arg.Is<Provider>(argument =>
argument.Name == organization.Name &&
argument.BillingEmail == organization.BillingEmail &&
argument.Status == ProviderStatusType.Pending &&
argument.Type == ProviderType.BusinessUnit)).Returns(provider);
var plan = StaticStore.GetPlan(organization.PlanType);
_pricingClient.GetPlanOrThrow(organization.PlanType).Returns(plan);
var token = SetupDataProtection(organization, providerAdminEmail);
var businessUnitConverter = BuildConverter();
var result = await businessUnitConverter.InitiateConversion(organization, providerAdminEmail);
Assert.True(result.IsT0);
var providerId = result.AsT0;
Assert.Equal(provider.Id, providerId);
await _providerOrganizationRepository.Received(1).CreateAsync(
Arg.Is<ProviderOrganization>(argument =>
argument.ProviderId == provider.Id &&
argument.OrganizationId == organization.Id));
await _providerPlanRepository.Received(1).CreateAsync(
Arg.Is<ProviderPlan>(argument =>
argument.ProviderId == provider.Id &&
argument.PlanType == PlanType.EnterpriseAnnually &&
argument.SeatMinimum == 0 &&
argument.PurchasedSeats == organization.Seats &&
argument.AllocatedSeats == organization.Seats));
await _providerUserRepository.Received(1).CreateAsync(
Arg.Is<ProviderUser>(argument =>
argument.ProviderId == provider.Id &&
argument.UserId == user.Id &&
argument.Email == user.Email &&
argument.Status == ProviderUserStatusType.Invited &&
argument.Type == ProviderUserType.ProviderAdmin));
await _mailService.Received(1).SendBusinessUnitConversionInviteAsync(
organization,
token,
user.Email);
}
[Theory, BitAutoData]
public async Task InitiateConversion_ValidationFails_ReturnsErrors(
Organization organization,
string providerAdminEmail)
{
organization.PlanType = PlanType.TeamsMonthly;
_subscriberService.GetSubscription(organization).Returns(new Subscription
{
Status = StripeConstants.SubscriptionStatus.Canceled
});
var user = new User
{
Id = Guid.NewGuid(),
Email = providerAdminEmail
};
_providerOrganizationRepository.GetByOrganizationId(organization.Id)
.Returns(new ProviderOrganization());
_userRepository.GetByEmailAsync(providerAdminEmail).Returns(user);
var organizationUser = new OrganizationUser { Status = OrganizationUserStatusType.Invited };
_organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id)
.Returns(organizationUser);
var businessUnitConverter = BuildConverter();
var result = await businessUnitConverter.InitiateConversion(organization, providerAdminEmail);
Assert.True(result.IsT1);
var problems = result.AsT1;
Assert.Contains("Organization must be on an enterprise plan.", problems);
Assert.Contains("Organization must have a valid subscription.", problems);
Assert.Contains("Organization is already linked to a provider.", problems);
Assert.Contains("Provider admin must be a confirmed member of the organization being converted.", problems);
}
#endregion
#region ResendConversionInvite
[Theory, BitAutoData]
public async Task ResendConversionInvite_ConversionInProgress_Succeeds(
Organization organization,
string providerAdminEmail)
{
SetupConversionInProgress(organization, providerAdminEmail);
var token = SetupDataProtection(organization, providerAdminEmail);
var businessUnitConverter = BuildConverter();
await businessUnitConverter.ResendConversionInvite(organization, providerAdminEmail);
await _mailService.Received(1).SendBusinessUnitConversionInviteAsync(
organization,
token,
providerAdminEmail);
}
[Theory, BitAutoData]
public async Task ResendConversionInvite_NoConversionInProgress_DoesNothing(
Organization organization,
string providerAdminEmail)
{
SetupDataProtection(organization, providerAdminEmail);
var businessUnitConverter = BuildConverter();
await businessUnitConverter.ResendConversionInvite(organization, providerAdminEmail);
await _mailService.DidNotReceiveWithAnyArgs().SendBusinessUnitConversionInviteAsync(
Arg.Any<Organization>(),
Arg.Any<string>(),
Arg.Any<string>());
}
#endregion
#region ResetConversion
[Theory, BitAutoData]
public async Task ResetConversion_ConversionInProgress_Succeeds(
Organization organization,
string providerAdminEmail)
{
var (provider, providerOrganization, providerUser, providerPlan) = SetupConversionInProgress(organization, providerAdminEmail);
var businessUnitConverter = BuildConverter();
await businessUnitConverter.ResetConversion(organization, providerAdminEmail);
await _providerOrganizationRepository.Received(1)
.DeleteAsync(providerOrganization);
await _providerUserRepository.Received(1)
.DeleteAsync(providerUser);
await _providerPlanRepository.Received(1)
.DeleteAsync(providerPlan);
await _providerRepository.Received(1)
.DeleteAsync(provider);
}
[Theory, BitAutoData]
public async Task ResetConversion_NoConversionInProgress_DoesNothing(
Organization organization,
string providerAdminEmail)
{
var businessUnitConverter = BuildConverter();
await businessUnitConverter.ResetConversion(organization, providerAdminEmail);
await _providerOrganizationRepository.DidNotReceiveWithAnyArgs()
.DeleteAsync(Arg.Any<ProviderOrganization>());
await _providerUserRepository.DidNotReceiveWithAnyArgs()
.DeleteAsync(Arg.Any<ProviderUser>());
await _providerPlanRepository.DidNotReceiveWithAnyArgs()
.DeleteAsync(Arg.Any<ProviderPlan>());
await _providerRepository.DidNotReceiveWithAnyArgs()
.DeleteAsync(Arg.Any<Provider>());
}
#endregion
#region Utilities
private string SetupDataProtection(
Organization organization,
string providerAdminEmail)
{
var dataProtector = new MockDataProtector(organization, providerAdminEmail);
_dataProtectionProvider.CreateProtector($"{nameof(BusinessUnitConverter)}DataProtector").Returns(dataProtector);
return dataProtector.Protect(dataProtector.Token);
}
private (Provider, ProviderOrganization, ProviderUser, ProviderPlan) SetupConversionInProgress(
Organization organization,
string providerAdminEmail)
{
var user = new User { Id = Guid.NewGuid() };
_userRepository.GetByEmailAsync(providerAdminEmail).Returns(user);
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.BusinessUnit,
Status = ProviderStatusType.Pending
};
_providerRepository.GetByOrganizationIdAsync(organization.Id).Returns(provider);
var providerUser = new ProviderUser
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
UserId = user.Id,
Type = ProviderUserType.ProviderAdmin,
Status = ProviderUserStatusType.Invited,
Email = providerAdminEmail
};
_providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id)
.Returns(providerUser);
var providerOrganization = new ProviderOrganization
{
Id = Guid.NewGuid(),
OrganizationId = organization.Id,
ProviderId = provider.Id
};
_providerOrganizationRepository.GetByOrganizationId(organization.Id)
.Returns(providerOrganization);
var providerPlan = new ProviderPlan
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.EnterpriseAnnually
};
_providerPlanRepository.GetByProviderId(provider.Id).Returns([providerPlan]);
return (provider, providerOrganization, providerUser, providerPlan);
}
#endregion
}
public class MockDataProtector(
Organization organization,
string providerAdminEmail) : IDataProtector
{
public string Token = $"BusinessUnitConversionInvite {organization.Id} {providerAdminEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}";
public IDataProtector CreateProtector(string purpose) => this;
public byte[] Protect(byte[] plaintext) => Encoding.UTF8.GetBytes(Token);
public byte[] Unprotect(byte[] protectedData) => Encoding.UTF8.GetBytes(Token);
}

View File

@ -4,6 +4,7 @@ using Bit.Commercial.Core.Billing;
using Bit.Commercial.Core.Billing.Models;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
@ -115,6 +116,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Type = ProviderType.BusinessUnit;
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
var existingPlan = new ProviderPlan
{
@ -132,10 +135,7 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(existingPlan.PlanType)
.Returns(StaticStore.GetPlan(existingPlan.PlanType));
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.ProviderSubscriptionGetAsync(
Arg.Is(provider.GatewaySubscriptionId),
Arg.Is(provider.Id))
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider)
.Returns(new Subscription
{
Id = provider.GatewaySubscriptionId,
@ -158,7 +158,7 @@ public class ProviderBillingServiceTests
});
var command =
new ChangeProviderPlanCommand(providerPlanId, PlanType.EnterpriseMonthly, provider.GatewaySubscriptionId);
new ChangeProviderPlanCommand(provider, providerPlanId, PlanType.EnterpriseMonthly);
sutProvider.GetDependency<IPricingClient>().GetPlanOrThrow(command.NewPlan)
.Returns(StaticStore.GetPlan(command.NewPlan));
@ -170,6 +170,8 @@ public class ProviderBillingServiceTests
await providerPlanRepository.Received(1)
.ReplaceAsync(Arg.Is<ProviderPlan>(p => p.PlanType == PlanType.EnterpriseMonthly));
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
await stripeAdapter.Received(1)
.SubscriptionUpdateAsync(
Arg.Is(provider.GatewaySubscriptionId),
@ -405,6 +407,23 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } },
new SubscriptionItem
{
Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise }
}
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
// 50 seats currently assigned with a seat minimum of 100
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
@ -427,11 +446,9 @@ public class ProviderBillingServiceTests
await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10);
// 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum
await sutProvider.GetDependency<IPaymentService>().DidNotReceiveWithAnyArgs().AdjustSeats(
Arg.Any<Provider>(),
Arg.Any<Bit.Core.Models.StaticStore.Plan>(),
Arg.Any<int>(),
Arg.Any<int>());
await sutProvider.GetDependency<IStripeAdapter>().DidNotReceiveWithAnyArgs().SubscriptionUpdateAsync(
Arg.Any<string>(),
Arg.Any<SubscriptionUpdateOptions>());
await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>(
pPlan => pPlan.AllocatedSeats == 60));
@ -474,6 +491,23 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } },
new SubscriptionItem
{
Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise }
}
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
// 95 seats currently assigned with a seat minimum of 100
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
@ -496,11 +530,12 @@ public class ProviderBillingServiceTests
await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10);
// 95 current + 10 seat scale = 105 seats, 5 above the minimum
await sutProvider.GetDependency<IPaymentService>().Received(1).AdjustSeats(
provider,
StaticStore.GetPlan(providerPlan.PlanType),
providerPlan.SeatMinimum!.Value,
105);
await sutProvider.GetDependency<IStripeAdapter>().Received(1).SubscriptionUpdateAsync(
provider.GatewaySubscriptionId,
Arg.Is<SubscriptionUpdateOptions>(
options =>
options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams &&
options.Items.First().Quantity == 105));
// 105 total seats - 100 minimum = 5 purchased seats
await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>(
@ -544,6 +579,23 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } },
new SubscriptionItem
{
Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise }
}
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
// 110 seats currently assigned with a seat minimum of 100
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
@ -566,11 +618,12 @@ public class ProviderBillingServiceTests
await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10);
// 110 current + 10 seat scale up = 120 seats
await sutProvider.GetDependency<IPaymentService>().Received(1).AdjustSeats(
provider,
StaticStore.GetPlan(providerPlan.PlanType),
110,
120);
await sutProvider.GetDependency<IStripeAdapter>().Received(1).SubscriptionUpdateAsync(
provider.GatewaySubscriptionId,
Arg.Is<SubscriptionUpdateOptions>(
options =>
options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams &&
options.Items.First().Quantity == 120));
// 120 total seats - 100 seat minimum = 20 purchased seats
await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>(
@ -614,6 +667,23 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id).Returns(providerPlans);
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } },
new SubscriptionItem
{
Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise }
}
]
}
};
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
// 110 seats currently assigned with a seat minimum of 100
var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
@ -636,11 +706,12 @@ public class ProviderBillingServiceTests
await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, -30);
// 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum.
await sutProvider.GetDependency<IPaymentService>().Received(1).AdjustSeats(
provider,
StaticStore.GetPlan(providerPlan.PlanType),
110,
providerPlan.SeatMinimum!.Value);
await sutProvider.GetDependency<IStripeAdapter>().Received(1).SubscriptionUpdateAsync(
provider.GatewaySubscriptionId,
Arg.Is<SubscriptionUpdateOptions>(
options =>
options.Items.First().Price == ProviderPriceAdapter.MSP.Active.Teams &&
options.Items.First().Quantity == providerPlan.SeatMinimum!.Value));
// Being below the seat minimum means no purchased seats.
await sutProvider.GetDependency<IProviderPlanRepository>().Received(1).ReplaceAsync(Arg.Is<ProviderPlan>(
@ -924,11 +995,15 @@ public class ProviderBillingServiceTests
{
provider.GatewaySubscriptionId = null;
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider).Returns(new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
});
sutProvider.GetDependency<ISubscriberService>()
.GetCustomerOrThrow(
provider,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")))
.Returns(new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
});
var providerPlans = new List<ProviderPlan>
{
@ -973,13 +1048,18 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider,
Provider provider)
{
provider.Type = ProviderType.Msp;
provider.GatewaySubscriptionId = null;
sutProvider.GetDependency<ISubscriberService>().GetCustomerOrThrow(provider).Returns(new Customer
var customer = new Customer
{
Id = "customer_id",
Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported }
});
};
sutProvider.GetDependency<ISubscriberService>()
.GetCustomerOrThrow(
provider,
Arg.Is<CustomerGetOptions>(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))).Returns(customer);
var providerPlans = new List<ProviderPlan>
{
@ -1012,11 +1092,21 @@ public class ProviderBillingServiceTests
sutProvider.GetDependency<IProviderPlanRepository>().GetByProviderId(provider.Id)
.Returns(providerPlans);
var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly);
var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly);
var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active };
sutProvider.GetDependency<IAutomaticTaxStrategy>()
.When(x => x.SetCreateOptions(
Arg.Is<SubscriptionCreateOptions>(options =>
options.Customer == "customer_id")
, Arg.Is<Customer>(p => p == customer)))
.Do(x =>
{
x.Arg<SubscriptionCreateOptions>().AutomaticTax = new SubscriptionAutomaticTaxOptions
{
Enabled = true
};
});
sutProvider.GetDependency<IStripeAdapter>().SubscriptionCreateAsync(Arg.Is<SubscriptionCreateOptions>(
sub =>
sub.AutomaticTax.Enabled == true &&
@ -1024,9 +1114,9 @@ public class ProviderBillingServiceTests
sub.Customer == "customer_id" &&
sub.DaysUntilDue == 30 &&
sub.Items.Count == 2 &&
sub.Items.ElementAt(0).Price == teamsPlan.PasswordManager.StripeProviderPortalSeatPlanId &&
sub.Items.ElementAt(0).Price == ProviderPriceAdapter.MSP.Active.Teams &&
sub.Items.ElementAt(0).Quantity == 100 &&
sub.Items.ElementAt(1).Price == enterprisePlan.PasswordManager.StripeProviderPortalSeatPlanId &&
sub.Items.ElementAt(1).Price == ProviderPriceAdapter.MSP.Active.Enterprise &&
sub.Items.ElementAt(1).Quantity == 100 &&
sub.Metadata["providerId"] == provider.Id.ToString() &&
sub.OffSession == true &&
@ -1048,8 +1138,7 @@ public class ProviderBillingServiceTests
{
// Arrange
var command = new UpdateProviderSeatMinimumsCommand(
provider.Id,
provider.GatewaySubscriptionId,
provider,
[
(PlanType.TeamsMonthly, -10),
(PlanType.EnterpriseMonthly, 50)
@ -1068,6 +1157,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1097,9 +1188,7 @@ public class ProviderBillingServiceTests
}
};
stripeAdapter.ProviderSubscriptionGetAsync(
provider.GatewaySubscriptionId,
provider.Id).Returns(subscription);
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
var providerPlans = new List<ProviderPlan>
{
@ -1116,8 +1205,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand(
provider.Id,
provider.GatewaySubscriptionId,
provider,
[
(PlanType.EnterpriseMonthly, 30),
(PlanType.TeamsMonthly, 20)
@ -1149,6 +1237,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1178,7 +1268,7 @@ public class ProviderBillingServiceTests
}
};
stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription);
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
var providerPlans = new List<ProviderPlan>
{
@ -1195,8 +1285,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand(
provider.Id,
provider.GatewaySubscriptionId,
provider,
[
(PlanType.EnterpriseMonthly, 70),
(PlanType.TeamsMonthly, 50)
@ -1228,6 +1317,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1257,7 +1348,7 @@ public class ProviderBillingServiceTests
}
};
stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription);
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
var providerPlans = new List<ProviderPlan>
{
@ -1274,8 +1365,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand(
provider.Id,
provider.GatewaySubscriptionId,
provider,
[
(PlanType.EnterpriseMonthly, 60),
(PlanType.TeamsMonthly, 60)
@ -1301,6 +1391,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1330,7 +1422,7 @@ public class ProviderBillingServiceTests
}
};
stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription);
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
var providerPlans = new List<ProviderPlan>
{
@ -1347,8 +1439,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand(
provider.Id,
provider.GatewaySubscriptionId,
provider,
[
(PlanType.EnterpriseMonthly, 80),
(PlanType.TeamsMonthly, 80)
@ -1380,6 +1471,8 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider)
{
// Arrange
provider.Type = ProviderType.Msp;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
@ -1409,7 +1502,7 @@ public class ProviderBillingServiceTests
}
};
stripeAdapter.ProviderSubscriptionGetAsync(provider.GatewaySubscriptionId, provider.Id).Returns(subscription);
sutProvider.GetDependency<ISubscriberService>().GetSubscriptionOrThrow(provider).Returns(subscription);
var providerPlans = new List<ProviderPlan>
{
@ -1426,8 +1519,7 @@ public class ProviderBillingServiceTests
providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans);
var command = new UpdateProviderSeatMinimumsCommand(
provider.Id,
provider.GatewaySubscriptionId,
provider,
[
(PlanType.EnterpriseMonthly, 70),
(PlanType.TeamsMonthly, 30)

View File

@ -0,0 +1,151 @@
using Bit.Commercial.Core.Billing;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.Billing.Enums;
using Stripe;
using Xunit;
namespace Bit.Commercial.Core.Test.Billing;
public class ProviderPriceAdapterTests
{
[Theory]
[InlineData("password-manager-provider-portal-enterprise-monthly-2024", PlanType.EnterpriseMonthly)]
[InlineData("password-manager-provider-portal-teams-monthly-2024", PlanType.TeamsMonthly)]
public void GetPriceId_MSP_Legacy_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.Msp
};
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = priceId } }
]
}
};
var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("provider-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)]
[InlineData("provider-portal-teams-monthly-2025", PlanType.TeamsMonthly)]
public void GetPriceId_MSP_Active_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.Msp
};
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = priceId } }
]
}
};
var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("password-manager-provider-portal-enterprise-annually-2024", PlanType.EnterpriseAnnually)]
[InlineData("password-manager-provider-portal-enterprise-monthly-2024", PlanType.EnterpriseMonthly)]
public void GetPriceId_BusinessUnit_Legacy_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.BusinessUnit
};
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = priceId } }
]
}
};
var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("business-unit-portal-enterprise-annually-2025", PlanType.EnterpriseAnnually)]
[InlineData("business-unit-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)]
public void GetPriceId_BusinessUnit_Active_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.BusinessUnit
};
var subscription = new Subscription
{
Items = new StripeList<SubscriptionItem>
{
Data =
[
new SubscriptionItem { Price = new Price { Id = priceId } }
]
}
};
var result = ProviderPriceAdapter.GetPriceId(provider, subscription, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("provider-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)]
[InlineData("provider-portal-teams-monthly-2025", PlanType.TeamsMonthly)]
public void GetActivePriceId_MSP_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.Msp
};
var result = ProviderPriceAdapter.GetActivePriceId(provider, planType);
Assert.Equal(result, priceId);
}
[Theory]
[InlineData("business-unit-portal-enterprise-annually-2025", PlanType.EnterpriseAnnually)]
[InlineData("business-unit-portal-enterprise-monthly-2025", PlanType.EnterpriseMonthly)]
public void GetActivePriceId_BusinessUnit_Succeeds(string priceId, PlanType planType)
{
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.BusinessUnit
};
var result = ProviderPriceAdapter.GetActivePriceId(provider, planType);
Assert.Equal(result, priceId);
}
}

View File

@ -1,9 +1,11 @@
using Bit.Commercial.Core.SecretsManager.Queries.Projects;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.SecretsManager.Repositories;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
@ -66,6 +68,9 @@ public class MaxProjectsQueryTests
SutProvider<MaxProjectsQuery> sutProvider, Organization organization)
{
organization.PlanType = planType;
sutProvider.GetDependency<IPricingClient>().GetPlan(planType).Returns(StaticStore.GetPlan(planType));
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1);
@ -106,6 +111,9 @@ public class MaxProjectsQueryTests
SutProvider<MaxProjectsQuery> sutProvider, Organization organization)
{
organization.PlanType = planType;
sutProvider.GetDependency<IPricingClient>().GetPlan(planType).Returns(StaticStore.GetPlan(planType));
sutProvider.GetDependency<IOrganizationRepository>().GetByIdAsync(organization.Id).Returns(organization);
sutProvider.GetDependency<IProjectRepository>().GetProjectCountByOrganizationIdAsync(organization.Id)
.Returns(projects);

View File

@ -20,6 +20,7 @@ public class GroupsControllerPatchTests : IClassFixture<ScimApplicationFactory>,
{
var databaseContext = _factory.GetDatabaseContext();
_factory.ReinitializeDbForTests(databaseContext);
return Task.CompletedTask;
}

View File

@ -1,251 +0,0 @@
using System.Text.Json;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Services;
using Bit.Scim.Groups.Interfaces;
using Bit.Scim.IntegrationTest.Factories;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
using Bit.Test.Common.Helpers;
using NSubstitute;
using NSubstitute.ExceptionExtensions;
using Xunit;
namespace Bit.Scim.IntegrationTest.Controllers.v2;
public class GroupsControllerPatchTestsvNext : IClassFixture<ScimApplicationFactory>, IAsyncLifetime
{
private readonly ScimApplicationFactory _factory;
public GroupsControllerPatchTestsvNext(ScimApplicationFactory factory)
{
_factory = factory;
// Enable the feature flag for new PatchGroupsCommand and stub out the old command to be safe
_factory.SubstituteService((IFeatureService featureService)
=> featureService.IsEnabled(FeatureFlagKeys.ShortcutDuplicatePatchRequests).Returns(true));
_factory.SubstituteService((IPatchGroupCommand patchGroupCommand)
=> patchGroupCommand.PatchGroupAsync(Arg.Any<Organization>(), Arg.Any<Guid>(), Arg.Any<ScimPatchModel>())
.ThrowsAsync(new Exception("This test suite should be testing the vNext command, but the existing command was called.")));
}
public Task InitializeAsync()
{
var databaseContext = _factory.GetDatabaseContext();
_factory.ReinitializeDbForTests(databaseContext);
return Task.CompletedTask;
}
Task IAsyncLifetime.DisposeAsync() => Task.CompletedTask;
[Fact]
public async Task Patch_ReplaceDisplayName_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var newDisplayName = "Patch Display Name";
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{newDisplayName}\"}}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
var group = databaseContext.Groups.FirstOrDefault(g => g.Id == groupId);
Assert.Equal(newDisplayName, group.Name);
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount, databaseContext.GroupUsers.Count());
Assert.True(databaseContext.GroupUsers.Any(gu => gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId1));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId4));
}
[Fact]
public async Task Patch_ReplaceMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "replace",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId2}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Single(databaseContext.GroupUsers);
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount - 1, databaseContext.GroupUsers.Count());
var groupUser = databaseContext.GroupUsers.FirstOrDefault();
Assert.Equal(ScimApplicationFactory.TestOrganizationUserId2, groupUser.OrganizationUserId);
}
[Fact]
public async Task Patch_AddSingleMember_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "add",
Path = $"members[value eq \"{ScimApplicationFactory.TestOrganizationUserId2}\"]",
Value = JsonDocument.Parse("{}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount + 1, databaseContext.GroupUsers.Count());
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId1));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId2));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId4));
}
[Fact]
public async Task Patch_AddListMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId2;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "add",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId2}\"}},{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId3}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId2));
Assert.True(databaseContext.GroupUsers.Any(gu => gu.GroupId == groupId && gu.OrganizationUserId == ScimApplicationFactory.TestOrganizationUserId3));
}
[Fact]
public async Task Patch_RemoveSingleMember_ReplaceDisplayName_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var newDisplayName = "Patch Display Name";
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = $"members[value eq \"{ScimApplicationFactory.TestOrganizationUserId1}\"]",
Value = JsonDocument.Parse("{}").RootElement
},
new ScimPatchModel.OperationModel
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{newDisplayName}\"}}").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Equal(ScimApplicationFactory.InitialGroupUsersCount - 1, databaseContext.GroupUsers.Count());
Assert.Equal(ScimApplicationFactory.InitialGroupCount, databaseContext.Groups.Count());
var group = databaseContext.Groups.FirstOrDefault(g => g.Id == groupId);
Assert.Equal(newDisplayName, group.Name);
}
[Fact]
public async Task Patch_RemoveListMembers_Success()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = ScimApplicationFactory.TestGroupId1;
var inputModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>()
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = "members",
Value = JsonDocument.Parse($"[{{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId1}\"}}, {{\"value\":\"{ScimApplicationFactory.TestOrganizationUserId4}\"}}]").RootElement
}
},
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
var databaseContext = _factory.GetDatabaseContext();
Assert.Empty(databaseContext.GroupUsers);
}
[Fact]
public async Task Patch_NotFound()
{
var organizationId = ScimApplicationFactory.TestOrganizationId1;
var groupId = Guid.NewGuid();
var inputModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>(),
Schemas = new List<string>() { ScimConstants.Scim2SchemaGroup }
};
var expectedResponse = new ScimErrorResponseModel
{
Status = StatusCodes.Status404NotFound,
Detail = "Group not found.",
Schemas = new List<string> { ScimConstants.Scim2SchemaError }
};
var context = await _factory.GroupsPatchAsync(organizationId, groupId, inputModel);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
var responseModel = JsonSerializer.Deserialize<ScimErrorResponseModel>(context.Response.Body, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
AssertHelper.AssertPropertyEqual(expectedResponse, responseModel);
}
}

View File

@ -1,9 +1,12 @@
using System.Text.Json;
using Bit.Core;
using Bit.Core.Enums;
using Bit.Core.Services;
using Bit.Scim.IntegrationTest.Factories;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
using Bit.Test.Common.Helpers;
using NSubstitute;
using Xunit;
namespace Bit.Scim.IntegrationTest.Controllers.v2;
@ -276,9 +279,18 @@ public class UsersControllerTests : IClassFixture<ScimApplicationFactory>, IAsyn
AssertHelper.AssertPropertyEqual(expectedResponse, responseModel);
}
[Fact]
public async Task Post_Success()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task Post_Success(bool isScimInviteUserOptimizationEnabled)
{
var localFactory = new ScimApplicationFactory();
localFactory.SubstituteService((IFeatureService featureService)
=> featureService.IsEnabled(FeatureFlagKeys.ScimInviteUserOptimization)
.Returns(isScimInviteUserOptimizationEnabled));
localFactory.ReinitializeDbForTests(localFactory.GetDatabaseContext());
var email = "user5@example.com";
var displayName = "Test User 5";
var externalId = "UE";
@ -306,7 +318,7 @@ public class UsersControllerTests : IClassFixture<ScimApplicationFactory>, IAsyn
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
var context = await _factory.UsersPostAsync(ScimApplicationFactory.TestOrganizationId1, inputModel);
var context = await localFactory.UsersPostAsync(ScimApplicationFactory.TestOrganizationId1, inputModel);
Assert.Equal(StatusCodes.Status201Created, context.Response.StatusCode);
@ -316,7 +328,7 @@ public class UsersControllerTests : IClassFixture<ScimApplicationFactory>, IAsyn
var responseModel = JsonSerializer.Deserialize<ScimUserResponseModel>(context.Response.Body, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase });
AssertHelper.AssertPropertyEqual(expectedResponse, responseModel, "Id");
var databaseContext = _factory.GetDatabaseContext();
var databaseContext = localFactory.GetDatabaseContext();
Assert.Equal(_initialUserCount + 1, databaseContext.OrganizationUsers.Count());
}

View File

@ -1,15 +1,18 @@
using System.Text.Json;
using AutoFixture;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Scim.Groups;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
@ -20,19 +23,16 @@ public class PatchGroupCommandTests
{
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceListMembers_Success(SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group, IEnumerable<Guid> userIds)
public async Task PatchGroup_ReplaceListMembers_Success(SutProvider<PatchGroupCommand> sutProvider,
Organization organization, Group group, IEnumerable<Guid> userIds)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetByIdAsync(group.Id)
.Returns(group);
var scimPatchModel = new Models.ScimPatchModel
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new ScimPatchModel.OperationModel
new()
{
Op = "replace",
Path = "members",
@ -42,26 +42,31 @@ public class PatchGroupCommandTests
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(organization, group.Id, scimPatchModel);
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).UpdateUsersAsync(group.Id, Arg.Is<IEnumerable<Guid>>(arg => arg.All(id => userIds.Contains(id))));
await sutProvider.GetDependency<IGroupRepository>().Received(1).UpdateUsersAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == userIds.Count() &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceDisplayNameFromPath_Success(SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group, string displayName)
public async Task PatchGroup_ReplaceDisplayNameFromPath_Success(
SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group, string displayName)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetByIdAsync(group.Id)
.Returns(group);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
var scimPatchModel = new Models.ScimPatchModel
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new ScimPatchModel.OperationModel
new()
{
Op = "replace",
Path = "displayname",
@ -71,27 +76,55 @@ public class PatchGroupCommandTests
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(organization, group.Id, scimPatchModel);
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IUpdateGroupCommand>().Received(1).UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
Assert.Equal(displayName, group.Name);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceDisplayNameFromPath_MissingOrganization_Throws(
SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group, string displayName)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns((Organization)null);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "replace",
Path = "displayname",
Value = JsonDocument.Parse($"\"{displayName}\"").RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.PatchGroupAsync(group, scimPatchModel));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceDisplayNameFromValueObject_Success(SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group, string displayName)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetByIdAsync(group.Id)
.Returns(group);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
var scimPatchModel = new Models.ScimPatchModel
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new ScimPatchModel.OperationModel
new()
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{displayName}\"}}").RootElement
@ -100,12 +133,39 @@ public class PatchGroupCommandTests
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(organization, group.Id, scimPatchModel);
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IUpdateGroupCommand>().Received(1).UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
Assert.Equal(displayName, group.Name);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceDisplayNameFromValueObject_MissingOrganization_Throws(
SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group, string displayName)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns((Organization)null);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{displayName}\"}}").RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.PatchGroupAsync(group, scimPatchModel));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddSingleMember_Success(SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group, ICollection<Guid> existingMembers, Guid userId)
@ -113,18 +173,14 @@ public class PatchGroupCommandTests
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetByIdAsync(group.Id)
.Returns(group);
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id)
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new Models.ScimPatchModel
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new ScimPatchModel.OperationModel
new()
{
Op = "add",
Path = $"members[value eq \"{userId}\"]",
@ -133,9 +189,47 @@ public class PatchGroupCommandTests
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(organization, group.Id, scimPatchModel);
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).UpdateUsersAsync(group.Id, Arg.Is<IEnumerable<Guid>>(arg => arg.All(id => existingMembers.Append(userId).Contains(id))));
await sutProvider.GetDependency<IGroupRepository>().Received(1).AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg => arg.Single() == userId));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddSingleMember_ReturnsEarlyIfAlreadyInGroup(
SutProvider<PatchGroupCommand> sutProvider,
Organization organization,
Group group,
ICollection<Guid> existingMembers)
{
// User being added is already in group
var userId = existingMembers.First();
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members[value eq \"{userId}\"]",
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>()
.DidNotReceiveWithAnyArgs()
.AddGroupUsersByIdAsync(default, default);
}
[Theory]
@ -145,18 +239,14 @@ public class PatchGroupCommandTests
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetByIdAsync(group.Id)
.Returns(group);
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id)
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new Models.ScimPatchModel
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new ScimPatchModel.OperationModel
new()
{
Op = "add",
Path = $"members",
@ -166,9 +256,101 @@ public class PatchGroupCommandTests
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(organization, group.Id, scimPatchModel);
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).UpdateUsersAsync(group.Id, Arg.Is<IEnumerable<Guid>>(arg => arg.All(id => existingMembers.Concat(userIds).Contains(id))));
await sutProvider.GetDependency<IGroupRepository>().Received(1).AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == userIds.Count &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddListMembers_IgnoresDuplicatesInRequest(
SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group,
ICollection<Guid> existingMembers)
{
// Create 3 userIds
var fixture = new Fixture { RepeatCount = 3 };
var userIds = fixture.CreateMany<Guid>().ToList();
// Copy the list and add a duplicate
var userIdsWithDuplicate = userIds.Append(userIds.First()).ToList();
Assert.Equal(4, userIdsWithDuplicate.Count);
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer
.Serialize(userIdsWithDuplicate
.Select(uid => new { value = uid })
.ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == 3 &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddListMembers_SuccessIfOnlySomeUsersAreInGroup(
SutProvider<PatchGroupCommand> sutProvider,
Organization organization, Group group,
ICollection<Guid> existingMembers,
ICollection<Guid> userIds)
{
// A user is already in the group, but some still need to be added
userIds.Add(existingMembers.First());
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(userIds.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>()
.Received(1)
.AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == userIds.Count &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
@ -177,10 +359,6 @@ public class PatchGroupCommandTests
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetByIdAsync(group.Id)
.Returns(group);
var scimPatchModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
@ -194,21 +372,19 @@ public class PatchGroupCommandTests
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(organization, group.Id, scimPatchModel);
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupService>().Received(1).DeleteUserAsync(group, userId, EventSystemUser.SCIM);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_RemoveListMembers_Success(SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group, ICollection<Guid> existingMembers)
public async Task PatchGroup_RemoveListMembers_Success(SutProvider<PatchGroupCommand> sutProvider,
Organization organization, Group group, ICollection<Guid> existingMembers)
{
List<Guid> usersToRemove = [existingMembers.First(), existingMembers.Skip(1).First()];
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetByIdAsync(group.Id)
.Returns(group);
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id)
.Returns(existingMembers);
@ -217,30 +393,58 @@ public class PatchGroupCommandTests
{
Operations = new List<ScimPatchModel.OperationModel>
{
new ScimPatchModel.OperationModel
new()
{
Op = "remove",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(existingMembers.Select(uid => new { value = uid }).ToArray())).RootElement
Value = JsonDocument.Parse(JsonSerializer.Serialize(usersToRemove.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(organization, group.Id, scimPatchModel);
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).UpdateUsersAsync(group.Id, Arg.Is<IEnumerable<Guid>>(arg => arg.All(id => existingMembers.Contains(id))));
var expectedRemainingUsers = existingMembers.Skip(2).ToList();
await sutProvider.GetDependency<IGroupRepository>()
.Received(1)
.UpdateUsersAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == expectedRemainingUsers.Count &&
arg.ToHashSet().SetEquals(expectedRemainingUsers)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_NoAction_Success(SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group)
public async Task PatchGroup_InvalidOperation_Success(SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetByIdAsync(group.Id)
.Returns(group);
var scimPatchModel = new Models.ScimPatchModel
{
Operations = [new ScimPatchModel.OperationModel { Op = "invalid operation" }],
Schemas = [ScimConstants.Scim2SchemaUser]
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
// Assert: no operation performed
await sutProvider.GetDependency<IGroupRepository>().DidNotReceiveWithAnyArgs().UpdateUsersAsync(default, default);
await sutProvider.GetDependency<IGroupRepository>().DidNotReceiveWithAnyArgs().GetManyUserIdsByIdAsync(default);
await sutProvider.GetDependency<IUpdateGroupCommand>().DidNotReceiveWithAnyArgs().UpdateGroupAsync(default, default);
await sutProvider.GetDependency<IGroupService>().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default);
// Assert: logging
sutProvider.GetDependency<ILogger<PatchGroupCommand>>().ReceivedWithAnyArgs().LogWarning(default);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_NoOperation_Success(
SutProvider<PatchGroupCommand> sutProvider, Organization organization, Group group)
{
group.OrganizationId = organization.Id;
var scimPatchModel = new Models.ScimPatchModel
{
@ -248,45 +452,11 @@ public class PatchGroupCommandTests
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(organization, group.Id, scimPatchModel);
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().DidNotReceiveWithAnyArgs().UpdateUsersAsync(default, default);
await sutProvider.GetDependency<IGroupRepository>().DidNotReceiveWithAnyArgs().GetManyUserIdsByIdAsync(default);
await sutProvider.GetDependency<IUpdateGroupCommand>().DidNotReceiveWithAnyArgs().UpdateGroupAsync(default, default);
await sutProvider.GetDependency<IGroupService>().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_NotFound_Throws(SutProvider<PatchGroupCommand> sutProvider, Organization organization, Guid groupId)
{
var scimPatchModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>(),
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.PatchGroupAsync(organization, groupId, scimPatchModel));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_MismatchingOrganizationId_Throws(SutProvider<PatchGroupCommand> sutProvider, Organization organization, Guid groupId)
{
var scimPatchModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>(),
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
sutProvider.GetDependency<IGroupRepository>()
.GetByIdAsync(groupId)
.Returns(new Group
{
Id = groupId,
OrganizationId = Guid.NewGuid()
});
await Assert.ThrowsAsync<NotFoundException>(async () => await sutProvider.Sut.PatchGroupAsync(organization, groupId, scimPatchModel));
}
}

View File

@ -1,381 +0,0 @@
using System.Text.Json;
using AutoFixture;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Scim.Groups;
using Bit.Scim.Models;
using Bit.Scim.Utilities;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Scim.Test.Groups;
[SutProviderCustomize]
public class PatchGroupCommandvNextTests
{
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceListMembers_Success(SutProvider<PatchGroupCommandvNext> sutProvider,
Organization organization, Group group, IEnumerable<Guid> userIds)
{
group.OrganizationId = organization.Id;
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "replace",
Path = "members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(userIds.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).UpdateUsersAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == userIds.Count() &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceDisplayNameFromPath_Success(
SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, string displayName)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "replace",
Path = "displayname",
Value = JsonDocument.Parse($"\"{displayName}\"").RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IUpdateGroupCommand>().Received(1).UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
Assert.Equal(displayName, group.Name);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_ReplaceDisplayNameFromValueObject_Success(SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, string displayName)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "replace",
Value = JsonDocument.Parse($"{{\"displayName\":\"{displayName}\"}}").RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IUpdateGroupCommand>().Received(1).UpdateGroupAsync(group, organization, EventSystemUser.SCIM);
Assert.Equal(displayName, group.Name);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddSingleMember_Success(SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, ICollection<Guid> existingMembers, Guid userId)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members[value eq \"{userId}\"]",
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg => arg.Single() == userId));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddSingleMember_ReturnsEarlyIfAlreadyInGroup(
SutProvider<PatchGroupCommandvNext> sutProvider,
Organization organization,
Group group,
ICollection<Guid> existingMembers)
{
// User being added is already in group
var userId = existingMembers.First();
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members[value eq \"{userId}\"]",
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>()
.DidNotReceiveWithAnyArgs()
.AddGroupUsersByIdAsync(default, default);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddListMembers_Success(SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, ICollection<Guid> existingMembers, ICollection<Guid> userIds)
{
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(userIds.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == userIds.Count &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddListMembers_IgnoresDuplicatesInRequest(
SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group,
ICollection<Guid> existingMembers)
{
// Create 3 userIds
var fixture = new Fixture { RepeatCount = 3 };
var userIds = fixture.CreateMany<Guid>().ToList();
// Copy the list and add a duplicate
var userIdsWithDuplicate = userIds.Append(userIds.First()).ToList();
Assert.Equal(4, userIdsWithDuplicate.Count);
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer
.Serialize(userIdsWithDuplicate
.Select(uid => new { value = uid })
.ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().Received(1).AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == 3 &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_AddListMembers_SuccessIfOnlySomeUsersAreInGroup(
SutProvider<PatchGroupCommandvNext> sutProvider,
Organization organization, Group group,
ICollection<Guid> existingMembers,
ICollection<Guid> userIds)
{
// A user is already in the group, but some still need to be added
userIds.Add(existingMembers.First());
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id, true)
.Returns(existingMembers);
var scimPatchModel = new ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "add",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(userIds.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>()
.Received(1)
.AddGroupUsersByIdAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == userIds.Count &&
arg.ToHashSet().SetEquals(userIds)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_RemoveSingleMember_Success(SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group, Guid userId)
{
group.OrganizationId = organization.Id;
var scimPatchModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new ScimPatchModel.OperationModel
{
Op = "remove",
Path = $"members[value eq \"{userId}\"]",
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupService>().Received(1).DeleteUserAsync(group, userId, EventSystemUser.SCIM);
}
[Theory]
[BitAutoData]
public async Task PatchGroup_RemoveListMembers_Success(SutProvider<PatchGroupCommandvNext> sutProvider,
Organization organization, Group group, ICollection<Guid> existingMembers)
{
List<Guid> usersToRemove = [existingMembers.First(), existingMembers.Skip(1).First()];
group.OrganizationId = organization.Id;
sutProvider.GetDependency<IGroupRepository>()
.GetManyUserIdsByIdAsync(group.Id)
.Returns(existingMembers);
var scimPatchModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>
{
new()
{
Op = "remove",
Path = $"members",
Value = JsonDocument.Parse(JsonSerializer.Serialize(usersToRemove.Select(uid => new { value = uid }).ToArray())).RootElement
}
},
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
var expectedRemainingUsers = existingMembers.Skip(2).ToList();
await sutProvider.GetDependency<IGroupRepository>()
.Received(1)
.UpdateUsersAsync(
group.Id,
Arg.Is<IEnumerable<Guid>>(arg =>
arg.Count() == expectedRemainingUsers.Count &&
arg.ToHashSet().SetEquals(expectedRemainingUsers)));
}
[Theory]
[BitAutoData]
public async Task PatchGroup_NoAction_Success(
SutProvider<PatchGroupCommandvNext> sutProvider, Organization organization, Group group)
{
group.OrganizationId = organization.Id;
var scimPatchModel = new Models.ScimPatchModel
{
Operations = new List<ScimPatchModel.OperationModel>(),
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
};
await sutProvider.Sut.PatchGroupAsync(group, scimPatchModel);
await sutProvider.GetDependency<IGroupRepository>().DidNotReceiveWithAnyArgs().UpdateUsersAsync(default, default);
await sutProvider.GetDependency<IGroupRepository>().DidNotReceiveWithAnyArgs().GetManyUserIdsByIdAsync(default);
await sutProvider.GetDependency<IUpdateGroupCommand>().DidNotReceiveWithAnyArgs().UpdateGroupAsync(default, default);
await sutProvider.GetDependency<IGroupService>().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default);
}
}

View File

@ -1,4 +1,5 @@
using System.Text.Json;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
@ -43,7 +44,7 @@ public class PatchUserCommandTests
await sutProvider.Sut.PatchUserAsync(organizationUser.OrganizationId, organizationUser.Id, scimPatchModel);
await sutProvider.GetDependency<IOrganizationService>().Received(1).RestoreUserAsync(organizationUser, EventSystemUser.SCIM);
await sutProvider.GetDependency<IRestoreOrganizationUserCommand>().Received(1).RestoreUserAsync(organizationUser, EventSystemUser.SCIM);
}
[Theory]
@ -71,7 +72,7 @@ public class PatchUserCommandTests
await sutProvider.Sut.PatchUserAsync(organizationUser.OrganizationId, organizationUser.Id, scimPatchModel);
await sutProvider.GetDependency<IOrganizationService>().Received(1).RestoreUserAsync(organizationUser, EventSystemUser.SCIM);
await sutProvider.GetDependency<IRestoreOrganizationUserCommand>().Received(1).RestoreUserAsync(organizationUser, EventSystemUser.SCIM);
}
[Theory]
@ -147,7 +148,7 @@ public class PatchUserCommandTests
await sutProvider.Sut.PatchUserAsync(organizationUser.OrganizationId, organizationUser.Id, scimPatchModel);
await sutProvider.GetDependency<IOrganizationService>().DidNotReceiveWithAnyArgs().RestoreUserAsync(default, EventSystemUser.SCIM);
await sutProvider.GetDependency<IRestoreOrganizationUserCommand>().DidNotReceiveWithAnyArgs().RestoreUserAsync(default, EventSystemUser.SCIM);
await sutProvider.GetDependency<IOrganizationService>().DidNotReceiveWithAnyArgs().RevokeUserAsync(default, EventSystemUser.SCIM);
}

View File

@ -27,7 +27,7 @@ public class PostUserCommandTests
ExternalId = externalId,
Emails = emails,
Active = true,
Schemas = new List<string> { ScimConstants.Scim2SchemaUser }
Schemas = [ScimConstants.Scim2SchemaUser]
};
sutProvider.GetDependency<IOrganizationUserRepository>()
@ -39,13 +39,16 @@ public class PostUserCommandTests
sutProvider.GetDependency<IPaymentService>().HasSecretsManagerStandalone(organization).Returns(true);
sutProvider.GetDependency<IOrganizationService>()
.InviteUserAsync(organizationId, invitingUserId: null, EventSystemUser.SCIM,
.InviteUserAsync(organizationId,
invitingUserId: null,
EventSystemUser.SCIM,
Arg.Is<OrganizationUserInvite>(i =>
i.Emails.Single().Equals(scimUserRequestModel.PrimaryEmail.ToLowerInvariant()) &&
i.Type == OrganizationUserType.User &&
!i.Collections.Any() &&
!i.Groups.Any() &&
i.AccessSecretsManager), externalId)
i.AccessSecretsManager),
externalId)
.Returns(newUser);
var user = await sutProvider.Sut.PostUserAsync(organizationId, scimUserRequestModel);

90
perf/load/sync.js Normal file
View File

@ -0,0 +1,90 @@
import http from "k6/http";
import { check, fail } from "k6";
import { authenticate } from "./helpers/auth.js";
const IDENTITY_URL = __ENV.IDENTITY_URL;
const API_URL = __ENV.API_URL;
const CLIENT_ID = __ENV.CLIENT_ID;
const AUTH_USERNAME = __ENV.AUTH_USER_EMAIL;
const AUTH_PASSWORD = __ENV.AUTH_USER_PASSWORD_HASH;
export const options = {
ext: {
loadimpact: {
projectID: 3639465,
name: "Sync",
},
},
scenarios: {
constant_load: {
executor: "constant-arrival-rate",
rate: 30,
timeUnit: "1m", // 0.5 requests / second
duration: "10m",
preAllocatedVUs: 5,
},
ramping_load: {
executor: "ramping-arrival-rate",
startRate: 30,
timeUnit: "1m", // 0.5 requests / second to start
stages: [
{ duration: "30s", target: 30 },
{ duration: "2m", target: 75 },
{ duration: "1m", target: 60 },
{ duration: "2m", target: 100 },
{ duration: "2m", target: 90 },
{ duration: "1m", target: 120 },
{ duration: "30s", target: 150 },
{ duration: "30s", target: 60 },
{ duration: "30s", target: 0 },
],
preAllocatedVUs: 20,
},
},
thresholds: {
http_req_failed: ["rate<0.01"],
http_req_duration: ["p(95)<1200"],
},
};
export function setup() {
return authenticate(IDENTITY_URL, CLIENT_ID, AUTH_USERNAME, AUTH_PASSWORD);
}
export default function (data) {
const params = {
headers: {
Accept: "application/json",
"Content-Type": "application/json",
Authorization: `Bearer ${data.access_token}`,
"X-ClientId": CLIENT_ID,
},
tags: { name: "Sync" },
};
const excludeDomains = Math.random() > 0.5;
const syncRes = http.get(`${API_URL}/sync?excludeDomains=${excludeDomains}`, params);
if (
!check(syncRes, {
"sync status is 200": (r) => r.status === 200,
})
) {
console.error(`Sync failed with status ${syncRes.status}: ${syncRes.body}`);
fail("sync status code was *not* 200");
}
if (syncRes.status === 200) {
const syncJson = syncRes.json();
check(syncJson, {
"sync response has profile": (j) => j.profile !== undefined,
"sync response has folders": (j) => Array.isArray(j.folders),
"sync response has collections": (j) => Array.isArray(j.collections),
"sync response has ciphers": (j) => Array.isArray(j.ciphers),
"sync response has policies": (j) => Array.isArray(j.policies),
"sync response has sends": (j) => Array.isArray(j.sends),
"sync response has correct object type": (j) => j.object === "sync"
});
}
}

View File

@ -14,9 +14,6 @@
<ProjectReference Include="..\Core\Core.csproj" />
<ProjectReference Include="..\..\util\SqliteMigrations\SqliteMigrations.csproj" />
</ItemGroup>
<ItemGroup>
<Folder Include="Billing\Controllers\" />
</ItemGroup>
<Choose>
<When Condition="!$(DefineConstants.Contains('OSS'))">

View File

@ -133,10 +133,10 @@ public class ProvidersController : Controller
return View(new CreateResellerProviderModel());
}
[HttpGet("providers/create/multi-organization-enterprise")]
public IActionResult CreateMultiOrganizationEnterprise(int enterpriseMinimumSeats, string ownerEmail = null)
[HttpGet("providers/create/business-unit")]
public IActionResult CreateBusinessUnit(int enterpriseMinimumSeats, string ownerEmail = null)
{
return View(new CreateMultiOrganizationEnterpriseProviderModel
return View(new CreateBusinessUnitProviderModel
{
OwnerEmail = ownerEmail,
EnterpriseSeatMinimum = enterpriseMinimumSeats
@ -157,7 +157,7 @@ public class ProvidersController : Controller
{
ProviderType.Msp => RedirectToAction("CreateMsp"),
ProviderType.Reseller => RedirectToAction("CreateReseller"),
ProviderType.MultiOrganizationEnterprise => RedirectToAction("CreateMultiOrganizationEnterprise"),
ProviderType.BusinessUnit => RedirectToAction("CreateBusinessUnit"),
_ => View(model)
};
}
@ -198,10 +198,10 @@ public class ProvidersController : Controller
return RedirectToAction("Edit", new { id = provider.Id });
}
[HttpPost("providers/create/multi-organization-enterprise")]
[HttpPost("providers/create/business-unit")]
[ValidateAntiForgeryToken]
[RequirePermission(Permission.Provider_Create)]
public async Task<IActionResult> CreateMultiOrganizationEnterprise(CreateMultiOrganizationEnterpriseProviderModel model)
public async Task<IActionResult> CreateBusinessUnit(CreateBusinessUnitProviderModel model)
{
if (!ModelState.IsValid)
{
@ -209,7 +209,7 @@ public class ProvidersController : Controller
}
var provider = model.ToProvider();
await _createProviderCommand.CreateMultiOrganizationEnterpriseAsync(
await _createProviderCommand.CreateBusinessUnitAsync(
provider,
model.OwnerEmail,
model.Plan.Value,
@ -300,29 +300,27 @@ public class ProvidersController : Controller
{
case ProviderType.Msp:
var updateMspSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand(
provider.Id,
provider.GatewaySubscriptionId,
provider,
[
(Plan: PlanType.TeamsMonthly, SeatsMinimum: model.TeamsMonthlySeatMinimum),
(Plan: PlanType.EnterpriseMonthly, SeatsMinimum: model.EnterpriseMonthlySeatMinimum)
]);
await _providerBillingService.UpdateSeatMinimums(updateMspSeatMinimumsCommand);
break;
case ProviderType.MultiOrganizationEnterprise:
case ProviderType.BusinessUnit:
{
var existingMoePlan = providerPlans.Single();
// 1. Change the plan and take over any old values.
var changeMoePlanCommand = new ChangeProviderPlanCommand(
provider,
existingMoePlan.Id,
model.Plan!.Value,
provider.GatewaySubscriptionId);
model.Plan!.Value);
await _providerBillingService.ChangePlan(changeMoePlanCommand);
// 2. Update the seat minimums.
var updateMoeSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand(
provider.Id,
provider.GatewaySubscriptionId,
provider,
[
(Plan: model.Plan!.Value, SeatsMinimum: model.EnterpriseMinimumSeats!.Value)
]);

View File

@ -6,7 +6,7 @@ using Bit.SharedWeb.Utilities;
namespace Bit.Admin.AdminConsole.Models;
public class CreateMultiOrganizationEnterpriseProviderModel : IValidatableObject
public class CreateBusinessUnitProviderModel : IValidatableObject
{
[Display(Name = "Owner Email")]
public string OwnerEmail { get; set; }
@ -22,7 +22,7 @@ public class CreateMultiOrganizationEnterpriseProviderModel : IValidatableObject
{
return new Provider
{
Type = ProviderType.MultiOrganizationEnterprise
Type = ProviderType.BusinessUnit
};
}
@ -30,17 +30,17 @@ public class CreateMultiOrganizationEnterpriseProviderModel : IValidatableObject
{
if (string.IsNullOrWhiteSpace(OwnerEmail))
{
var ownerEmailDisplayName = nameof(OwnerEmail).GetDisplayAttribute<CreateMultiOrganizationEnterpriseProviderModel>()?.GetName() ?? nameof(OwnerEmail);
var ownerEmailDisplayName = nameof(OwnerEmail).GetDisplayAttribute<CreateBusinessUnitProviderModel>()?.GetName() ?? nameof(OwnerEmail);
yield return new ValidationResult($"The {ownerEmailDisplayName} field is required.");
}
if (EnterpriseSeatMinimum < 0)
{
var enterpriseSeatMinimumDisplayName = nameof(EnterpriseSeatMinimum).GetDisplayAttribute<CreateMultiOrganizationEnterpriseProviderModel>()?.GetName() ?? nameof(EnterpriseSeatMinimum);
var enterpriseSeatMinimumDisplayName = nameof(EnterpriseSeatMinimum).GetDisplayAttribute<CreateBusinessUnitProviderModel>()?.GetName() ?? nameof(EnterpriseSeatMinimum);
yield return new ValidationResult($"The {enterpriseSeatMinimumDisplayName} field can not be negative.");
}
if (Plan != PlanType.EnterpriseAnnually && Plan != PlanType.EnterpriseMonthly)
{
var planDisplayName = nameof(Plan).GetDisplayAttribute<CreateMultiOrganizationEnterpriseProviderModel>()?.GetName() ?? nameof(Plan);
var planDisplayName = nameof(Plan).GetDisplayAttribute<CreateBusinessUnitProviderModel>()?.GetName() ?? nameof(Plan);
yield return new ValidationResult($"The {planDisplayName} field must be set to Enterprise Annually or Enterprise Monthly.");
}
}

View File

@ -86,6 +86,7 @@ public class OrganizationEditModel : OrganizationViewModel
UseApi = org.UseApi;
UseSecretsManager = org.UseSecretsManager;
UseRiskInsights = org.UseRiskInsights;
UseAdminSponsoredFamilies = org.UseAdminSponsoredFamilies;
UseResetPassword = org.UseResetPassword;
SelfHost = org.SelfHost;
UsersGetPremium = org.UsersGetPremium;
@ -154,6 +155,8 @@ public class OrganizationEditModel : OrganizationViewModel
public new bool UseSecretsManager { get; set; }
[Display(Name = "Risk Insights")]
public new bool UseRiskInsights { get; set; }
[Display(Name = "Admin Sponsored Families")]
public bool UseAdminSponsoredFamilies { get; set; }
[Display(Name = "Self Host")]
public bool SelfHost { get; set; }
[Display(Name = "Users Get Premium")]
@ -295,6 +298,7 @@ public class OrganizationEditModel : OrganizationViewModel
existingOrganization.UseApi = UseApi;
existingOrganization.UseSecretsManager = UseSecretsManager;
existingOrganization.UseRiskInsights = UseRiskInsights;
existingOrganization.UseAdminSponsoredFamilies = UseAdminSponsoredFamilies;
existingOrganization.UseResetPassword = UseResetPassword;
existingOrganization.SelfHost = SelfHost;
existingOrganization.UsersGetPremium = UsersGetPremium;

View File

@ -34,7 +34,7 @@ public class ProviderEditModel : ProviderViewModel, IValidatableObject
GatewaySubscriptionUrl = gatewaySubscriptionUrl;
Type = provider.Type;
if (Type == ProviderType.MultiOrganizationEnterprise)
if (Type == ProviderType.BusinessUnit)
{
var plan = providerPlans.SingleOrDefault();
EnterpriseMinimumSeats = plan?.SeatMinimum ?? 0;
@ -100,7 +100,7 @@ public class ProviderEditModel : ProviderViewModel, IValidatableObject
yield return new ValidationResult($"The {billingEmailDisplayName} field is required.");
}
break;
case ProviderType.MultiOrganizationEnterprise:
case ProviderType.BusinessUnit:
if (Plan == null)
{
var displayName = nameof(Plan).GetDisplayAttribute<CreateProviderModel>()?.GetName() ?? nameof(Plan);

View File

@ -40,7 +40,7 @@ public class ProviderViewModel
ProviderPlanViewModels.Add(new ProviderPlanViewModel("Enterprise (Monthly) Subscription", enterpriseProviderPlan, usedEnterpriseSeats));
}
}
else if (Provider.Type == ProviderType.MultiOrganizationEnterprise)
else if (Provider.Type == ProviderType.BusinessUnit)
{
var usedEnterpriseSeats = ProviderOrganizations.Where(po => po.PlanType == PlanType.EnterpriseMonthly)
.Sum(po => po.OccupiedSeats).GetValueOrDefault(0);

View File

@ -1,8 +1,13 @@
@using Bit.Admin.Enums;
@using Bit.Admin.Models
@using Bit.Core
@using Bit.Core.AdminConsole.Enums.Provider
@using Bit.Core.Billing.Enums
@using Bit.Core.Enums
@using Bit.Core.Billing.Extensions
@using Bit.Core.Services
@using Microsoft.AspNetCore.Mvc.TagHelpers
@inject Bit.Admin.Services.IAccessControlService AccessControlService
@inject IFeatureService FeatureService
@model OrganizationEditModel
@{
ViewData["Title"] = (Model.Provider != null ? "Client " : string.Empty) + "Organization: " + Model.Name;
@ -13,6 +18,13 @@
var canRequestDelete = AccessControlService.UserHasPermission(Permission.Org_RequestDelete);
var canDelete = AccessControlService.UserHasPermission(Permission.Org_Delete);
var canUnlinkFromProvider = AccessControlService.UserHasPermission(Permission.Provider_Edit);
var canConvertToBusinessUnit =
FeatureService.IsEnabled(FeatureFlagKeys.PM18770_EnableOrganizationBusinessUnitConversion) &&
AccessControlService.UserHasPermission(Permission.Org_Billing_ConvertToBusinessUnit) &&
Model.Organization.PlanType.GetProductTier() == ProductTierType.Enterprise &&
!string.IsNullOrEmpty(Model.Organization.GatewaySubscriptionId) &&
Model.Provider is null or { Type: ProviderType.BusinessUnit, Status: ProviderStatusType.Pending };
}
@section Scripts {
@ -114,6 +126,15 @@
Enterprise Trial
</button>
}
@if (canConvertToBusinessUnit)
{
<a asp-controller="BusinessUnitConversion"
asp-action="Index"
asp-route-organizationId="@Model.Organization.Id"
class="btn btn-secondary me-2">
Convert to Business Unit
</a>
}
@if (canUnlinkFromProvider && Model.Provider is not null)
{
<button class="btn btn-outline-danger me-2"

View File

@ -1,15 +1,15 @@
@using Bit.Core.Billing.Enums
@using Microsoft.AspNetCore.Mvc.TagHelpers
@model CreateMultiOrganizationEnterpriseProviderModel
@model CreateBusinessUnitProviderModel
@{
ViewData["Title"] = "Create Multi-organization Enterprise Provider";
ViewData["Title"] = "Create Business Unit Provider";
}
<h1 class="mb-4">Create Multi-organization Enterprise Provider</h1>
<h1 class="mb-4">Create Business Unit Provider</h1>
<div>
<form method="post" asp-action="CreateMultiOrganizationEnterprise">
<form method="post" asp-action="CreateBusinessUnit">
<div asp-validation-summary="All" class="alert alert-danger"></div>
<div class="mb-3">
<label asp-for="OwnerEmail" class="form-label"></label>
@ -19,14 +19,14 @@
<div class="col-sm">
<div class="mb-3">
@{
var multiOrgPlans = new List<PlanType>
var businessUnitPlanTypes = new List<PlanType>
{
PlanType.EnterpriseAnnually,
PlanType.EnterpriseMonthly
};
}
<label asp-for="Plan" class="form-label"></label>
<select class="form-select" asp-for="Plan" asp-items="Html.GetEnumSelectList(multiOrgPlans)">
<select class="form-select" asp-for="Plan" asp-items="Html.GetEnumSelectList(businessUnitPlanTypes)">
<option value="">--</option>
</select>
</div>

View File

@ -74,20 +74,20 @@
</div>
break;
}
case ProviderType.MultiOrganizationEnterprise:
case ProviderType.BusinessUnit:
{
<div class="row">
<div class="col-sm">
<div class="mb-3">
@{
var multiOrgPlans = new List<PlanType>
var businessUnitPlanTypes = new List<PlanType>
{
PlanType.EnterpriseAnnually,
PlanType.EnterpriseMonthly
};
}
<label asp-for="Plan" class="form-label"></label>
<select class="form-control" asp-for="Plan" asp-items="Html.GetEnumSelectList(multiOrgPlans)">
<select class="form-control" asp-for="Plan" asp-items="Html.GetEnumSelectList(businessUnitPlanTypes)">
<option value="">--</option>
</select>
</div>

View File

@ -0,0 +1,185 @@
#nullable enable
using Bit.Admin.Billing.Models;
using Bit.Admin.Enums;
using Bit.Admin.Utilities;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Billing.Controllers;
[Authorize]
[Route("organizations/billing/{organizationId:guid}/business-unit")]
[RequireFeature(FeatureFlagKeys.PM18770_EnableOrganizationBusinessUnitConversion)]
public class BusinessUnitConversionController(
IBusinessUnitConverter businessUnitConverter,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
IProviderUserRepository providerUserRepository) : Controller
{
[HttpGet]
[RequirePermission(Permission.Org_Billing_ConvertToBusinessUnit)]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> IndexAsync([FromRoute] Guid organizationId)
{
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
var model = new BusinessUnitConversionModel { Organization = organization };
var invitedProviderAdmin = await GetInvitedProviderAdminAsync(organization);
if (invitedProviderAdmin != null)
{
model.ProviderAdminEmail = invitedProviderAdmin.Email;
model.ProviderId = invitedProviderAdmin.ProviderId;
}
var success = ReadSuccessMessage();
if (!string.IsNullOrEmpty(success))
{
model.Success = success;
}
var errors = ReadErrorMessages();
if (errors is { Count: > 0 })
{
model.Errors = errors;
}
return View(model);
}
[HttpPost]
[RequirePermission(Permission.Org_Billing_ConvertToBusinessUnit)]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> InitiateAsync(
[FromRoute] Guid organizationId,
BusinessUnitConversionModel model)
{
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
var result = await businessUnitConverter.InitiateConversion(
organization,
model.ProviderAdminEmail!);
return result.Match(
providerId => RedirectToAction("Edit", "Providers", new { id = providerId }),
errors =>
{
PersistErrorMessages(errors);
return RedirectToAction("Index", new { organizationId });
});
}
[HttpPost("reset")]
[RequirePermission(Permission.Org_Billing_ConvertToBusinessUnit)]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> ResetAsync(
[FromRoute] Guid organizationId,
BusinessUnitConversionModel model)
{
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
await businessUnitConverter.ResetConversion(organization, model.ProviderAdminEmail!);
PersistSuccessMessage("Business unit conversion was successfully reset.");
return RedirectToAction("Index", new { organizationId });
}
[HttpPost("resend-invite")]
[RequirePermission(Permission.Org_Billing_ConvertToBusinessUnit)]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> ResendInviteAsync(
[FromRoute] Guid organizationId,
BusinessUnitConversionModel model)
{
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
await businessUnitConverter.ResendConversionInvite(organization, model.ProviderAdminEmail!);
PersistSuccessMessage($"Invite was successfully resent to {model.ProviderAdminEmail}.");
return RedirectToAction("Index", new { organizationId });
}
private async Task<ProviderUser?> GetInvitedProviderAdminAsync(
Organization organization)
{
var provider = await providerRepository.GetByOrganizationIdAsync(organization.Id);
if (provider is not
{
Type: ProviderType.BusinessUnit,
Status: ProviderStatusType.Pending
})
{
return null;
}
var providerUsers =
await providerUserRepository.GetManyByProviderAsync(provider.Id, ProviderUserType.ProviderAdmin);
if (providerUsers.Count != 1)
{
return null;
}
var providerUser = providerUsers.First();
return providerUser is
{
Type: ProviderUserType.ProviderAdmin,
Status: ProviderUserStatusType.Invited,
UserId: not null
} ? providerUser : null;
}
private const string _errors = "errors";
private const string _success = "Success";
private void PersistSuccessMessage(string message) => TempData[_success] = message;
private void PersistErrorMessages(List<string> errors)
{
var input = string.Join("|", errors);
TempData[_errors] = input;
}
private string? ReadSuccessMessage() => ReadTempData<string>(_success);
private List<string>? ReadErrorMessages()
{
var output = ReadTempData<string>(_errors);
return string.IsNullOrEmpty(output) ? null : output.Split('|').ToList();
}
private T? ReadTempData<T>(string key) => TempData.TryGetValue(key, out var obj) && obj is T value ? value : default;
}

View File

@ -0,0 +1,25 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Core.AdminConsole.Entities;
using Microsoft.AspNetCore.Mvc.ModelBinding;
namespace Bit.Admin.Billing.Models;
public class BusinessUnitConversionModel
{
[Required]
[EmailAddress]
[Display(Name = "Provider Admin Email")]
public string? ProviderAdminEmail { get; set; }
[BindNever]
public required Organization Organization { get; set; }
[BindNever]
public Guid? ProviderId { get; set; }
[BindNever]
public string? Success { get; set; }
[BindNever] public List<string>? Errors { get; set; } = [];
}

View File

@ -0,0 +1,75 @@
@model Bit.Admin.Billing.Models.BusinessUnitConversionModel
@{
ViewData["Title"] = "Convert Organization to Business Unit";
}
@if (!string.IsNullOrEmpty(Model.ProviderAdminEmail))
{
<h1>Convert @Model.Organization.Name to Business Unit</h1>
@if (!string.IsNullOrEmpty(Model.Success))
{
<div class="alert alert-success alert-dismissible fade show mb-3" role="alert">
@Model.Success
<button type="button" class="btn-close" data-bs-dismiss="alert" aria-label="Close"></button>
</div>
}
@if (Model.Errors?.Any() ?? false)
{
@foreach (var error in Model.Errors)
{
<div class="alert alert-danger alert-dismissible fade show mb-3" role="alert">
@error
<button type="button" class="btn-close" data-bs-dismiss="alert" aria-label="Close"></button>
</div>
}
}
<p>This organization has a business unit conversion in progress.</p>
<div class="mb-3">
<label asp-for="ProviderAdminEmail" class="form-label"></label>
<input type="email" class="form-control" asp-for="ProviderAdminEmail" disabled></input>
</div>
<div class="d-flex gap-2">
<form method="post" asp-controller="BusinessUnitConversion" asp-action="ResendInvite" asp-route-organizationId="@Model.Organization.Id">
<input type="hidden" asp-for="ProviderAdminEmail" />
<button type="submit" class="btn btn-primary mb-2">Resend Invite</button>
</form>
<form method="post" asp-controller="BusinessUnitConversion" asp-action="Reset" asp-route-organizationId="@Model.Organization.Id">
<input type="hidden" asp-for="ProviderAdminEmail" />
<button type="submit" class="btn btn-danger mb-2">Reset Conversion</button>
</form>
@if (Model.ProviderId.HasValue)
{
<a asp-controller="Providers"
asp-action="Edit"
asp-route-id="@Model.ProviderId"
class="btn btn-secondary mb-2">
Go to Provider
</a>
}
</div>
}
else
{
<h1>Convert @Model.Organization.Name to Business Unit</h1>
@if (Model.Errors?.Any() ?? false)
{
@foreach (var error in Model.Errors)
{
<div class="alert alert-danger alert-dismissible fade show mb-3" role="alert">
@error
<button type="button" class="btn-close" data-bs-dismiss="alert" aria-label="Close"></button>
</div>
}
}
<form method="post" asp-controller="BusinessUnitConversion" asp-action="Initiate" asp-route-organizationId="@Model.Organization.Id">
<div asp-validation-summary="All" class="alert alert-danger"></div>
<div class="mb-3">
<label asp-for="ProviderAdminEmail" class="form-label"></label>
<input type="email" class="form-control" asp-for="ProviderAdminEmail" />
</div>
<button type="submit" class="btn btn-primary mb-2">Convert</button>
</form>
}

View File

@ -184,7 +184,7 @@ public class UsersController : Controller
private async Task<bool?> AccountDeprovisioningEnabled(Guid userId)
{
return _featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)
? await _userService.IsManagedByAnyOrganizationAsync(userId)
? await _userService.IsClaimedByAnyOrganizationAsync(userId)
: null;
}
}

View File

@ -38,6 +38,7 @@ public enum Permission
Org_Billing_View,
Org_Billing_Edit,
Org_Billing_LaunchGateway,
Org_Billing_ConvertToBusinessUnit,
Provider_List_View,
Provider_Create,

View File

@ -42,6 +42,7 @@ public static class RolePermissionMapping
Permission.Org_Billing_View,
Permission.Org_Billing_Edit,
Permission.Org_Billing_LaunchGateway,
Permission.Org_Billing_ConvertToBusinessUnit,
Permission.Provider_List_View,
Permission.Provider_Create,
Permission.Provider_View,
@ -90,6 +91,7 @@ public static class RolePermissionMapping
Permission.Org_Billing_View,
Permission.Org_Billing_Edit,
Permission.Org_Billing_LaunchGateway,
Permission.Org_Billing_ConvertToBusinessUnit,
Permission.Org_InitiateTrial,
Permission.Provider_List_View,
Permission.Provider_Create,
@ -166,6 +168,7 @@ public static class RolePermissionMapping
Permission.Org_Billing_View,
Permission.Org_Billing_Edit,
Permission.Org_Billing_LaunchGateway,
Permission.Org_Billing_ConvertToBusinessUnit,
Permission.Org_RequestDelete,
Permission.Provider_Edit,
Permission.Provider_View,

View File

@ -0,0 +1,21 @@
#nullable enable
using Microsoft.AspNetCore.Authorization;
namespace Bit.Api.AdminConsole.Authorization;
/// <summary>
/// An attribute which requires authorization using the specified requirement.
/// This uses the standard ASP.NET authorization middleware.
/// </summary>
/// <typeparam name="T">The IAuthorizationRequirement that will be used to authorize the user.</typeparam>
public class AuthorizeAttribute<T>
: AuthorizeAttribute, IAuthorizationRequirementData
where T : IAuthorizationRequirement, new()
{
public IEnumerable<IAuthorizationRequirement> GetRequirements()
{
var requirement = new T();
return [requirement];
}
}

View File

@ -0,0 +1,79 @@
#nullable enable
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.AdminConsole.Repositories;
namespace Bit.Api.AdminConsole.Authorization;
public static class HttpContextExtensions
{
public const string NoOrgIdError =
"A route decorated with with '[Authorize<Requirement>]' must include a route value named 'orgId' either through the [Controller] attribute or through a '[Http*]' attribute.";
/// <summary>
/// Returns the result of the callback, caching it in HttpContext.Features for the lifetime of the request.
/// Subsequent calls will retrieve the cached value.
/// Results are stored by type and therefore must be of a unique type.
/// </summary>
public static async Task<T> WithFeaturesCacheAsync<T>(this HttpContext httpContext, Func<Task<T>> callback)
{
var cachedResult = httpContext.Features.Get<T>();
if (cachedResult != null)
{
return cachedResult;
}
var result = await callback();
httpContext.Features.Set(result);
return result;
}
/// <summary>
/// Returns true if the user is a ProviderUser for a Provider which manages the specified organization, otherwise false.
/// </summary>
/// <remarks>
/// This data is fetched from the database and cached as a HttpContext Feature for the lifetime of the request.
/// </remarks>
public static async Task<bool> IsProviderUserForOrgAsync(
this HttpContext httpContext,
IProviderUserRepository providerUserRepository,
Guid userId,
Guid organizationId)
{
var organizations = await httpContext.GetProviderUserOrganizationsAsync(providerUserRepository, userId);
return organizations.Any(o => o.OrganizationId == organizationId);
}
/// <summary>
/// Returns the ProviderUserOrganizations for a user. These are the organizations the ProviderUser manages via their Provider, if any.
/// </summary>
/// <remarks>
/// This data is fetched from the database and cached as a HttpContext Feature for the lifetime of the request.
/// </remarks>
private static async Task<IEnumerable<ProviderUserOrganizationDetails>> GetProviderUserOrganizationsAsync(
this HttpContext httpContext,
IProviderUserRepository providerUserRepository,
Guid userId)
=> await httpContext.WithFeaturesCacheAsync(() =>
providerUserRepository.GetManyOrganizationDetailsByUserAsync(userId, ProviderUserStatusType.Confirmed));
/// <summary>
/// Parses the {orgId} route parameter into a Guid, or throws if the {orgId} is not present or not a valid guid.
/// </summary>
/// <param name="httpContext"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public static Guid GetOrganizationId(this HttpContext httpContext)
{
httpContext.GetRouteData().Values.TryGetValue("orgId", out var orgIdParam);
if (orgIdParam == null || !Guid.TryParse(orgIdParam.ToString(), out var orgId))
{
throw new InvalidOperationException(NoOrgIdError);
}
return orgId;
}
}

View File

@ -0,0 +1,31 @@
#nullable enable
using Bit.Core.Context;
using Microsoft.AspNetCore.Authorization;
namespace Bit.Api.AdminConsole.Authorization;
/// <summary>
/// A requirement that implements this interface will be handled by <see cref="OrganizationRequirementHandler"/>,
/// which calls AuthorizeAsync with the organization details from the route.
/// This is used for simple role-based checks.
/// This may only be used on endpoints with {orgId} in their path.
/// </summary>
public interface IOrganizationRequirement : IAuthorizationRequirement
{
/// <summary>
/// Whether to authorize a request that has this requirement.
/// </summary>
/// <param name="organizationClaims">
/// The CurrentContextOrganization for the user if they are a member of the organization.
/// This is null if they are not a member.
/// </param>
/// <param name="isProviderUserForOrg">
/// A callback that returns true if the user is a ProviderUser that manages the organization, otherwise false.
/// This requires a database query, call it last.
/// </param>
/// <returns>True if the requirement has been satisfied, otherwise false.</returns>
public Task<bool> AuthorizeAsync(
CurrentContextOrganization? organizationClaims,
Func<Task<bool>> isProviderUserForOrg);
}

View File

@ -0,0 +1,119 @@
#nullable enable
using System.Security.Claims;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Identity;
using Bit.Core.Models.Data;
namespace Bit.Api.AdminConsole.Authorization;
public static class OrganizationClaimsExtensions
{
/// <summary>
/// Parses a user's claims and returns an object representing their claims for the specified organization.
/// </summary>
/// <param name="user">The user who has the claims.</param>
/// <param name="organizationId">The organizationId to look for in the claims.</param>
/// <returns>
/// A <see cref="CurrentContextOrganization"/> representing the user's claims for that organization, or null
/// if the user does not have any claims for that organization.
/// </returns>
public static CurrentContextOrganization? GetCurrentContextOrganization(this ClaimsPrincipal user, Guid organizationId)
{
var hasClaim = GetClaimsParser(user, organizationId);
var role = GetRoleFromClaims(hasClaim);
if (!role.HasValue)
{
// Not an organization member
return null;
}
return new CurrentContextOrganization
{
Id = organizationId,
Type = role.Value,
AccessSecretsManager = hasClaim(Claims.SecretsManagerAccess),
Permissions = role == OrganizationUserType.Custom
? GetPermissionsFromClaims(hasClaim)
: new Permissions()
};
}
/// <summary>
/// Returns a function for evaluating claims for the specified user and organizationId.
/// The function returns true if the claim type exists and false otherwise.
/// </summary>
private static Func<string, bool> GetClaimsParser(ClaimsPrincipal user, Guid organizationId)
{
// Group claims by ClaimType
var claimsDict = user.Claims
.GroupBy(c => c.Type)
.ToDictionary(
c => c.Key,
c => c.ToList());
return claimType
=> claimsDict.TryGetValue(claimType, out var claims) &&
claims
.ParseGuids()
.Any(v => v == organizationId);
}
/// <summary>
/// Parses the provided claims into proper Guids, or ignore them if they are not valid guids.
/// </summary>
private static IEnumerable<Guid> ParseGuids(this IEnumerable<Claim> claims)
{
foreach (var claim in claims)
{
if (Guid.TryParse(claim.Value, out var guid))
{
yield return guid;
}
}
}
private static OrganizationUserType? GetRoleFromClaims(Func<string, bool> hasClaim)
{
if (hasClaim(Claims.OrganizationOwner))
{
return OrganizationUserType.Owner;
}
if (hasClaim(Claims.OrganizationAdmin))
{
return OrganizationUserType.Admin;
}
if (hasClaim(Claims.OrganizationCustom))
{
return OrganizationUserType.Custom;
}
if (hasClaim(Claims.OrganizationUser))
{
return OrganizationUserType.User;
}
return null;
}
private static Permissions GetPermissionsFromClaims(Func<string, bool> hasClaim)
=> new()
{
AccessEventLogs = hasClaim(Claims.CustomPermissions.AccessEventLogs),
AccessImportExport = hasClaim(Claims.CustomPermissions.AccessImportExport),
AccessReports = hasClaim(Claims.CustomPermissions.AccessReports),
CreateNewCollections = hasClaim(Claims.CustomPermissions.CreateNewCollections),
EditAnyCollection = hasClaim(Claims.CustomPermissions.EditAnyCollection),
DeleteAnyCollection = hasClaim(Claims.CustomPermissions.DeleteAnyCollection),
ManageGroups = hasClaim(Claims.CustomPermissions.ManageGroups),
ManagePolicies = hasClaim(Claims.CustomPermissions.ManagePolicies),
ManageSso = hasClaim(Claims.CustomPermissions.ManageSso),
ManageUsers = hasClaim(Claims.CustomPermissions.ManageUsers),
ManageResetPassword = hasClaim(Claims.CustomPermissions.ManageResetPassword),
ManageScim = hasClaim(Claims.CustomPermissions.ManageScim),
};
}

View File

@ -0,0 +1,49 @@
#nullable enable
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Services;
using Microsoft.AspNetCore.Authorization;
namespace Bit.Api.AdminConsole.Authorization;
/// <summary>
/// Handles any requirement that implements <see cref="IOrganizationRequirement"/>.
/// Retrieves the Organization ID from the route and then passes it to the requirement's AuthorizeAsync callback to
/// determine whether the action is authorized.
/// </summary>
public class OrganizationRequirementHandler(
IHttpContextAccessor httpContextAccessor,
IProviderUserRepository providerUserRepository,
IUserService userService)
: AuthorizationHandler<IOrganizationRequirement>
{
public const string NoHttpContextError = "This method should only be called in the context of an HTTP Request.";
public const string NoUserIdError = "This method should only be called on the private api with a logged in user.";
protected override async Task HandleRequirementAsync(AuthorizationHandlerContext context, IOrganizationRequirement requirement)
{
var httpContext = httpContextAccessor.HttpContext;
if (httpContext == null)
{
throw new InvalidOperationException(NoHttpContextError);
}
var organizationId = httpContext.GetOrganizationId();
var organizationClaims = httpContext.User.GetCurrentContextOrganization(organizationId);
var userId = userService.GetProperUserId(httpContext.User);
if (userId == null)
{
throw new InvalidOperationException(NoUserIdError);
}
Task<bool> IsProviderUserForOrg() => httpContext.IsProviderUserForOrgAsync(providerUserRepository, userId.Value, organizationId);
var authorized = await requirement.AuthorizeAsync(organizationClaims, IsProviderUserForOrg);
if (authorized)
{
context.Succeed(requirement);
}
}
}

View File

@ -0,0 +1,20 @@
#nullable enable
using Bit.Core.Context;
using Bit.Core.Enums;
namespace Bit.Api.AdminConsole.Authorization.Requirements;
public class ManageUsersRequirement : IOrganizationRequirement
{
public async Task<bool> AuthorizeAsync(
CurrentContextOrganization? organizationClaims,
Func<Task<bool>> isProviderUserForOrg)
=> organizationClaims switch
{
{ Type: OrganizationUserType.Owner } => true,
{ Type: OrganizationUserType.Admin } => true,
{ Permissions.ManageUsers: true } => true,
_ => await isProviderUserForOrg()
};
}

View File

@ -0,0 +1,16 @@
#nullable enable
using Bit.Core.Context;
namespace Bit.Api.AdminConsole.Authorization.Requirements;
/// <summary>
/// Requires that the user is a member of the organization or a provider for the organization.
/// </summary>
public class MemberOrProviderRequirement : IOrganizationRequirement
{
public async Task<bool> AuthorizeAsync(
CurrentContextOrganization? organizationClaims,
Func<Task<bool>> isProviderUserForOrg)
=> organizationClaims is not null || await isProviderUserForOrg();
}

View File

@ -8,6 +8,9 @@ using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Authorization;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.OrganizationFeatures.Shared.Authorization;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Enums;
@ -53,10 +56,13 @@ public class OrganizationUsersController : Controller
private readonly IOrganizationUserUserDetailsQuery _organizationUserUserDetailsQuery;
private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery;
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
private readonly IDeleteManagedOrganizationUserAccountCommand _deleteManagedOrganizationUserAccountCommand;
private readonly IGetOrganizationUsersManagementStatusQuery _getOrganizationUsersManagementStatusQuery;
private readonly IDeleteClaimedOrganizationUserAccountCommand _deleteClaimedOrganizationUserAccountCommand;
private readonly IGetOrganizationUsersClaimedStatusQuery _getOrganizationUsersClaimedStatusQuery;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IFeatureService _featureService;
private readonly IPricingClient _pricingClient;
private readonly IConfirmOrganizationUserCommand _confirmOrganizationUserCommand;
private readonly IRestoreOrganizationUserCommand _restoreOrganizationUserCommand;
public OrganizationUsersController(
IOrganizationRepository organizationRepository,
@ -77,10 +83,13 @@ public class OrganizationUsersController : Controller
IOrganizationUserUserDetailsQuery organizationUserUserDetailsQuery,
ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
IRemoveOrganizationUserCommand removeOrganizationUserCommand,
IDeleteManagedOrganizationUserAccountCommand deleteManagedOrganizationUserAccountCommand,
IGetOrganizationUsersManagementStatusQuery getOrganizationUsersManagementStatusQuery,
IDeleteClaimedOrganizationUserAccountCommand deleteClaimedOrganizationUserAccountCommand,
IGetOrganizationUsersClaimedStatusQuery getOrganizationUsersClaimedStatusQuery,
IPolicyRequirementQuery policyRequirementQuery,
IFeatureService featureService,
IPricingClient pricingClient)
IPricingClient pricingClient,
IConfirmOrganizationUserCommand confirmOrganizationUserCommand,
IRestoreOrganizationUserCommand restoreOrganizationUserCommand)
{
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
@ -100,10 +109,13 @@ public class OrganizationUsersController : Controller
_organizationUserUserDetailsQuery = organizationUserUserDetailsQuery;
_twoFactorIsEnabledQuery = twoFactorIsEnabledQuery;
_removeOrganizationUserCommand = removeOrganizationUserCommand;
_deleteManagedOrganizationUserAccountCommand = deleteManagedOrganizationUserAccountCommand;
_getOrganizationUsersManagementStatusQuery = getOrganizationUsersManagementStatusQuery;
_deleteClaimedOrganizationUserAccountCommand = deleteClaimedOrganizationUserAccountCommand;
_getOrganizationUsersClaimedStatusQuery = getOrganizationUsersClaimedStatusQuery;
_policyRequirementQuery = policyRequirementQuery;
_featureService = featureService;
_pricingClient = pricingClient;
_confirmOrganizationUserCommand = confirmOrganizationUserCommand;
_restoreOrganizationUserCommand = restoreOrganizationUserCommand;
}
[HttpGet("{id}")]
@ -115,11 +127,11 @@ public class OrganizationUsersController : Controller
throw new NotFoundException();
}
var managedByOrganization = await GetManagedByOrganizationStatusAsync(
var claimedByOrganizationStatus = await GetClaimedByOrganizationStatusAsync(
organizationUser.OrganizationId,
[organizationUser.Id]);
var response = new OrganizationUserDetailsResponseModel(organizationUser, managedByOrganization[organizationUser.Id], collections);
var response = new OrganizationUserDetailsResponseModel(organizationUser, claimedByOrganizationStatus[organizationUser.Id], collections);
if (includeGroups)
{
@ -163,13 +175,13 @@ public class OrganizationUsersController : Controller
}
);
var organizationUsersTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(organizationUsers);
var organizationUsersManagementStatus = await GetManagedByOrganizationStatusAsync(orgId, organizationUsers.Select(o => o.Id));
var organizationUsersClaimedStatus = await GetClaimedByOrganizationStatusAsync(orgId, organizationUsers.Select(o => o.Id));
var responses = organizationUsers
.Select(o =>
{
var userTwoFactorEnabled = organizationUsersTwoFactorEnabled.FirstOrDefault(u => u.user.Id == o.Id).twoFactorIsEnabled;
var managedByOrganization = organizationUsersManagementStatus[o.Id];
var orgUser = new OrganizationUserUserDetailsResponseModel(o, userTwoFactorEnabled, managedByOrganization);
var claimedByOrganization = organizationUsersClaimedStatus[o.Id];
var orgUser = new OrganizationUserUserDetailsResponseModel(o, userTwoFactorEnabled, claimedByOrganization);
return orgUser;
});
@ -301,9 +313,9 @@ public class OrganizationUsersController : Controller
throw new UnauthorizedAccessException();
}
await _organizationService.InitPendingOrganization(user.Id, orgId, organizationUserId, model.Keys.PublicKey, model.Keys.EncryptedPrivateKey, model.CollectionName);
await _organizationService.InitPendingOrganization(user, orgId, organizationUserId, model.Keys.PublicKey, model.Keys.EncryptedPrivateKey, model.CollectionName, model.Token);
await _acceptOrgUserCommand.AcceptOrgUserByEmailTokenAsync(organizationUserId, user, model.Token, _userService);
await _organizationService.ConfirmUserAsync(orgId, organizationUserId, model.Key, user.Id);
await _confirmOrganizationUserCommand.ConfirmUserAsync(orgId, organizationUserId, model.Key, user.Id);
}
[HttpPost("{organizationUserId}/accept")]
@ -315,11 +327,13 @@ public class OrganizationUsersController : Controller
throw new UnauthorizedAccessException();
}
var useMasterPasswordPolicy = await ShouldHandleResetPasswordAsync(orgId);
var useMasterPasswordPolicy = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements)
? (await _policyRequirementQuery.GetAsync<ResetPasswordPolicyRequirement>(user.Id)).AutoEnrollEnabled(orgId)
: await ShouldHandleResetPasswordAsync(orgId);
if (useMasterPasswordPolicy && string.IsNullOrWhiteSpace(model.ResetPasswordKey))
{
throw new BadRequestException(string.Empty, "Master Password reset is required, but not provided.");
throw new BadRequestException("Master Password reset is required, but not provided.");
}
await _acceptOrgUserCommand.AcceptOrgUserByEmailTokenAsync(organizationUserId, user, model.Token, _userService);
@ -357,7 +371,7 @@ public class OrganizationUsersController : Controller
}
var userId = _userService.GetProperUserId(User);
var result = await _organizationService.ConfirmUserAsync(orgGuidId, new Guid(id), model.Key, userId.Value);
var result = await _confirmOrganizationUserCommand.ConfirmUserAsync(orgGuidId, new Guid(id), model.Key, userId.Value);
}
[HttpPost("confirm")]
@ -371,7 +385,7 @@ public class OrganizationUsersController : Controller
}
var userId = _userService.GetProperUserId(User);
var results = await _organizationService.ConfirmUsersAsync(orgGuidId, model.ToDictionary(), userId.Value);
var results = await _confirmOrganizationUserCommand.ConfirmUsersAsync(orgGuidId, model.ToDictionary(), userId.Value);
return new ListResponseModel<OrganizationUserBulkResponseModel>(results.Select(r =>
new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2)));
@ -577,7 +591,7 @@ public class OrganizationUsersController : Controller
throw new UnauthorizedAccessException();
}
await _deleteManagedOrganizationUserAccountCommand.DeleteUserAsync(orgId, id, currentUser.Id);
await _deleteClaimedOrganizationUserAccountCommand.DeleteUserAsync(orgId, id, currentUser.Id);
}
[RequireFeature(FeatureFlagKeys.AccountDeprovisioning)]
@ -596,7 +610,7 @@ public class OrganizationUsersController : Controller
throw new UnauthorizedAccessException();
}
var results = await _deleteManagedOrganizationUserAccountCommand.DeleteManyUsersAsync(orgId, model.Ids, currentUser.Id);
var results = await _deleteClaimedOrganizationUserAccountCommand.DeleteManyUsersAsync(orgId, model.Ids, currentUser.Id);
return new ListResponseModel<OrganizationUserBulkResponseModel>(results.Select(r =>
new OrganizationUserBulkResponseModel(r.OrganizationUserId, r.ErrorMessage)));
@ -620,14 +634,14 @@ public class OrganizationUsersController : Controller
[HttpPut("{id}/restore")]
public async Task RestoreAsync(Guid orgId, Guid id)
{
await RestoreOrRevokeUserAsync(orgId, id, (orgUser, userId) => _organizationService.RestoreUserAsync(orgUser, userId));
await RestoreOrRevokeUserAsync(orgId, id, (orgUser, userId) => _restoreOrganizationUserCommand.RestoreUserAsync(orgUser, userId));
}
[HttpPatch("restore")]
[HttpPut("restore")]
public async Task<ListResponseModel<OrganizationUserBulkResponseModel>> BulkRestoreAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model)
{
return await RestoreOrRevokeUsersAsync(orgId, model, (orgId, orgUserIds, restoringUserId) => _organizationService.RestoreUsersAsync(orgId, orgUserIds, restoringUserId, _userService));
return await RestoreOrRevokeUsersAsync(orgId, model, (orgId, orgUserIds, restoringUserId) => _restoreOrganizationUserCommand.RestoreUsersAsync(orgId, orgUserIds, restoringUserId, _userService));
}
[HttpPatch("enable-secrets-manager")]
@ -703,14 +717,14 @@ public class OrganizationUsersController : Controller
new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2)));
}
private async Task<IDictionary<Guid, bool>> GetManagedByOrganizationStatusAsync(Guid orgId, IEnumerable<Guid> userIds)
private async Task<IDictionary<Guid, bool>> GetClaimedByOrganizationStatusAsync(Guid orgId, IEnumerable<Guid> userIds)
{
if (!_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning))
{
return userIds.ToDictionary(kvp => kvp, kvp => false);
}
var usersOrganizationManagementStatus = await _getOrganizationUsersManagementStatusQuery.GetUsersOrganizationManagementStatusAsync(orgId, userIds);
return usersOrganizationManagementStatus;
var usersOrganizationClaimedStatus = await _getOrganizationUsersClaimedStatusQuery.GetUsersOrganizationClaimedStatusAsync(orgId, userIds);
return usersOrganizationClaimedStatus;
}
}

View File

@ -16,6 +16,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationApiKeys.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Repositories;
@ -61,7 +63,9 @@ public class OrganizationsController : Controller
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
private readonly ICloudOrganizationSignUpCommand _cloudOrganizationSignUpCommand;
private readonly IOrganizationDeleteCommand _organizationDeleteCommand;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IPricingClient _pricingClient;
private readonly IOrganizationUpdateKeysCommand _organizationUpdateKeysCommand;
public OrganizationsController(
IOrganizationRepository organizationRepository,
@ -84,7 +88,9 @@ public class OrganizationsController : Controller
IRemoveOrganizationUserCommand removeOrganizationUserCommand,
ICloudOrganizationSignUpCommand cloudOrganizationSignUpCommand,
IOrganizationDeleteCommand organizationDeleteCommand,
IPricingClient pricingClient)
IPolicyRequirementQuery policyRequirementQuery,
IPricingClient pricingClient,
IOrganizationUpdateKeysCommand organizationUpdateKeysCommand)
{
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
@ -106,7 +112,9 @@ public class OrganizationsController : Controller
_removeOrganizationUserCommand = removeOrganizationUserCommand;
_cloudOrganizationSignUpCommand = cloudOrganizationSignUpCommand;
_organizationDeleteCommand = organizationDeleteCommand;
_policyRequirementQuery = policyRequirementQuery;
_pricingClient = pricingClient;
_organizationUpdateKeysCommand = organizationUpdateKeysCommand;
}
[HttpGet("{id}")]
@ -135,10 +143,10 @@ public class OrganizationsController : Controller
var organizations = await _organizationUserRepository.GetManyDetailsByUserAsync(userId,
OrganizationUserStatusType.Confirmed);
var organizationManagingActiveUser = await _userService.GetOrganizationsManagingUserAsync(userId);
var organizationIdsManagingActiveUser = organizationManagingActiveUser.Select(o => o.Id);
var organizationsClaimingActiveUser = await _userService.GetOrganizationsClaimingUserAsync(userId);
var organizationIdsClaimingActiveUser = organizationsClaimingActiveUser.Select(o => o.Id);
var responses = organizations.Select(o => new ProfileOrganizationResponseModel(o, organizationIdsManagingActiveUser));
var responses = organizations.Select(o => new ProfileOrganizationResponseModel(o, organizationIdsClaimingActiveUser));
return new ListResponseModel<ProfileOrganizationResponseModel>(responses);
}
@ -163,8 +171,13 @@ public class OrganizationsController : Controller
throw new NotFoundException();
}
var resetPasswordPolicy =
await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword);
if (_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements))
{
var resetPasswordPolicyRequirement = await _policyRequirementQuery.GetAsync<ResetPasswordPolicyRequirement>(user.Id);
return new OrganizationAutoEnrollStatusResponseModel(organization.Id, resetPasswordPolicyRequirement.AutoEnrollEnabled(organization.Id));
}
var resetPasswordPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword);
if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled || resetPasswordPolicy.Data == null)
{
return new OrganizationAutoEnrollStatusResponseModel(organization.Id, false);
@ -172,6 +185,7 @@ public class OrganizationsController : Controller
var data = JsonSerializer.Deserialize<ResetPasswordDataModel>(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase);
return new OrganizationAutoEnrollStatusResponseModel(organization.Id, data?.AutoEnrollEnabled ?? false);
}
[HttpPost("")]
@ -266,9 +280,9 @@ public class OrganizationsController : Controller
}
if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)
&& (await _userService.GetOrganizationsManagingUserAsync(user.Id)).Any(x => x.Id == id))
&& (await _userService.GetOrganizationsClaimingUserAsync(user.Id)).Any(x => x.Id == id))
{
throw new BadRequestException("Managed user account cannot leave managing organization. Contact your organization administrator for additional details.");
throw new BadRequestException("Claimed user account cannot leave claiming organization. Contact your organization administrator for additional details.");
}
await _removeOrganizationUserCommand.UserLeaveAsync(id, user.Id);
@ -479,7 +493,7 @@ public class OrganizationsController : Controller
}
[HttpPost("{id}/keys")]
public async Task<OrganizationKeysResponseModel> PostKeys(string id, [FromBody] OrganizationKeysRequestModel model)
public async Task<OrganizationKeysResponseModel> PostKeys(Guid id, [FromBody] OrganizationKeysRequestModel model)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
@ -487,7 +501,7 @@ public class OrganizationsController : Controller
throw new UnauthorizedAccessException();
}
var org = await _organizationService.UpdateOrganizationKeysAsync(new Guid(id), model.PublicKey,
var org = await _organizationUpdateKeysCommand.UpdateOrganizationKeysAsync(id, model.PublicKey,
model.EncryptedPrivateKey);
return new OrganizationKeysResponseModel(org);
}

View File

@ -13,7 +13,17 @@ public static class PolicyDetailResponses
{
throw new ArgumentException($"'{nameof(policy)}' must be of type '{nameof(PolicyType.SingleOrg)}'.", nameof(policy));
}
return new PolicyDetailResponseModel(policy, await CanToggleState());
return new PolicyDetailResponseModel(policy, !await hasVerifiedDomainsQuery.HasVerifiedDomainsAsync(policy.OrganizationId));
async Task<bool> CanToggleState()
{
if (!await hasVerifiedDomainsQuery.HasVerifiedDomainsAsync(policy.OrganizationId))
{
return true;
}
return !policy.Enabled;
}
}
}

View File

@ -64,6 +64,7 @@ public class OrganizationResponseModel : ResponseModel
LimitItemDeletion = organization.LimitItemDeletion;
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems;
UseRiskInsights = organization.UseRiskInsights;
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
}
public Guid Id { get; set; }
@ -110,6 +111,7 @@ public class OrganizationResponseModel : ResponseModel
public bool LimitItemDeletion { get; set; }
public bool AllowAdminAccessToAllCollectionItems { get; set; }
public bool UseRiskInsights { get; set; }
public bool UseAdminSponsoredFamilies { get; set; }
}
public class OrganizationSubscriptionResponseModel : OrganizationResponseModel

View File

@ -66,24 +66,34 @@ public class OrganizationUserDetailsResponseModel : OrganizationUserResponseMode
{
public OrganizationUserDetailsResponseModel(
OrganizationUser organizationUser,
bool managedByOrganization,
bool claimedByOrganization,
string ssoExternalId,
IEnumerable<CollectionAccessSelection> collections)
: base(organizationUser, "organizationUserDetails")
{
ManagedByOrganization = managedByOrganization;
ClaimedByOrganization = claimedByOrganization;
SsoExternalId = ssoExternalId;
Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c));
}
public OrganizationUserDetailsResponseModel(OrganizationUserUserDetails organizationUser,
bool managedByOrganization,
bool claimedByOrganization,
IEnumerable<CollectionAccessSelection> collections)
: base(organizationUser, "organizationUserDetails")
{
ManagedByOrganization = managedByOrganization;
ClaimedByOrganization = claimedByOrganization;
SsoExternalId = organizationUser.SsoExternalId;
Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c));
}
public bool ManagedByOrganization { get; set; }
[Obsolete("Please use ClaimedByOrganization instead. This property will be removed in a future version.")]
public bool ManagedByOrganization
{
get => ClaimedByOrganization;
set => ClaimedByOrganization = value;
}
public bool ClaimedByOrganization { get; set; }
public string SsoExternalId { get; set; }
public IEnumerable<SelectionReadOnlyResponseModel> Collections { get; set; }
@ -117,7 +127,7 @@ public class OrganizationUserUserMiniDetailsResponseModel : ResponseModel
public class OrganizationUserUserDetailsResponseModel : OrganizationUserResponseModel
{
public OrganizationUserUserDetailsResponseModel(OrganizationUserUserDetails organizationUser,
bool twoFactorEnabled, bool managedByOrganization, string obj = "organizationUserUserDetails")
bool twoFactorEnabled, bool claimedByOrganization, string obj = "organizationUserUserDetails")
: base(organizationUser, obj)
{
if (organizationUser == null)
@ -134,7 +144,7 @@ public class OrganizationUserUserDetailsResponseModel : OrganizationUserResponse
Groups = organizationUser.Groups;
// Prevent reset password when using key connector.
ResetPasswordEnrolled = ResetPasswordEnrolled && !organizationUser.UsesKeyConnector;
ManagedByOrganization = managedByOrganization;
ClaimedByOrganization = claimedByOrganization;
}
public string Name { get; set; }
@ -142,11 +152,17 @@ public class OrganizationUserUserDetailsResponseModel : OrganizationUserResponse
public string AvatarColor { get; set; }
public bool TwoFactorEnabled { get; set; }
public bool SsoBound { get; set; }
[Obsolete("Please use ClaimedByOrganization instead. This property will be removed in a future version.")]
public bool ManagedByOrganization
{
get => ClaimedByOrganization;
set => ClaimedByOrganization = value;
}
/// <summary>
/// Indicates if the organization manages the user. If a user is "managed" by an organization,
/// Indicates if the organization claimed the user. If a user is "claimed" by an organization,
/// the organization has greater control over their account, and some user actions are restricted.
/// </summary>
public bool ManagedByOrganization { get; set; }
public bool ClaimedByOrganization { get; set; }
public IEnumerable<SelectionReadOnlyResponseModel> Collections { get; set; }
public IEnumerable<Guid> Groups { get; set; }
}

View File

@ -18,7 +18,7 @@ public class ProfileOrganizationResponseModel : ResponseModel
public ProfileOrganizationResponseModel(
OrganizationUserOrganizationDetails organization,
IEnumerable<Guid> organizationIdsManagingUser)
IEnumerable<Guid> organizationIdsClaimingUser)
: this("profileOrganization")
{
Id = organization.OrganizationId;
@ -51,7 +51,7 @@ public class ProfileOrganizationResponseModel : ResponseModel
SsoBound = !string.IsNullOrWhiteSpace(organization.SsoExternalId);
Identifier = organization.Identifier;
Permissions = CoreHelpers.LoadClassFromJsonData<Permissions>(organization.Permissions);
ResetPasswordEnrolled = organization.ResetPasswordKey != null;
ResetPasswordEnrolled = !string.IsNullOrWhiteSpace(organization.ResetPasswordKey);
UserId = organization.UserId;
OrganizationUserId = organization.OrganizationUserId;
ProviderId = organization.ProviderId;
@ -70,8 +70,9 @@ public class ProfileOrganizationResponseModel : ResponseModel
LimitCollectionDeletion = organization.LimitCollectionDeletion;
LimitItemDeletion = organization.LimitItemDeletion;
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems;
UserIsManagedByOrganization = organizationIdsManagingUser.Contains(organization.OrganizationId);
UserIsClaimedByOrganization = organizationIdsClaimingUser.Contains(organization.OrganizationId);
UseRiskInsights = organization.UseRiskInsights;
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
if (organization.SsoConfig != null)
{
@ -133,15 +134,27 @@ public class ProfileOrganizationResponseModel : ResponseModel
public bool LimitItemDeletion { get; set; }
public bool AllowAdminAccessToAllCollectionItems { get; set; }
/// <summary>
/// Indicates if the organization manages the user.
/// Obsolete.
///
/// See <see cref="UserIsClaimedByOrganization"/>
/// </summary>
[Obsolete("Please use UserIsClaimedByOrganization instead. This property will be removed in a future version.")]
public bool UserIsManagedByOrganization
{
get => UserIsClaimedByOrganization;
set => UserIsClaimedByOrganization = value;
}
/// <summary>
/// Indicates if the organization claims the user.
/// </summary>
/// <remarks>
/// An organization manages a user if the user's email domain is verified by the organization and the user is a member of it.
/// An organization claims a user if the user's email domain is verified by the organization and the user is a member of it.
/// The organization must be enabled and able to have verified domains.
/// </remarks>
/// <returns>
/// False if the Account Deprovisioning feature flag is disabled.
/// </returns>
public bool UserIsManagedByOrganization { get; set; }
public bool UserIsClaimedByOrganization { get; set; }
public bool UseRiskInsights { get; set; }
public bool UseAdminSponsoredFamilies { get; set; }
}

View File

@ -50,5 +50,6 @@ public class ProfileProviderOrganizationResponseModel : ProfileOrganizationRespo
LimitItemDeletion = organization.LimitItemDeletion;
AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems;
UseRiskInsights = organization.UseRiskInsights;
UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies;
}
}

View File

@ -22,6 +22,7 @@ public class ProfileProviderResponseModel : ResponseModel
UserId = provider.UserId;
UseEvents = provider.UseEvents;
ProviderStatus = provider.ProviderStatus;
ProviderType = provider.ProviderType;
}
public Guid Id { get; set; }
@ -35,4 +36,5 @@ public class ProfileProviderResponseModel : ResponseModel
public Guid? UserId { get; set; }
public bool UseEvents { get; set; }
public ProviderStatusType ProviderStatus { get; set; }
public ProviderType ProviderType { get; set; }
}

View File

@ -30,7 +30,7 @@ public class MemberResponseModel : MemberBaseModel, IResponseModel
Email = user.Email;
Status = user.Status;
Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c));
ResetPasswordEnrolled = user.ResetPasswordKey != null;
ResetPasswordEnrolled = !string.IsNullOrWhiteSpace(user.ResetPasswordKey);
}
[SetsRequiredMembers]
@ -49,7 +49,7 @@ public class MemberResponseModel : MemberBaseModel, IResponseModel
TwoFactorEnabled = twoFactorEnabled;
Status = user.Status;
Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c));
ResetPasswordEnrolled = user.ResetPasswordKey != null;
ResetPasswordEnrolled = !string.IsNullOrWhiteSpace(user.ResetPasswordKey);
SsoExternalId = user.SsoExternalId;
}

View File

@ -124,11 +124,11 @@ public class AccountsController : Controller
throw new BadRequestException("MasterPasswordHash", "Invalid password.");
}
var managedUserValidationResult = await _userService.ValidateManagedUserDomainAsync(user, model.NewEmail);
var claimedUserValidationResult = await _userService.ValidateClaimedUserDomainAsync(user, model.NewEmail);
if (!managedUserValidationResult.Succeeded)
if (!claimedUserValidationResult.Succeeded)
{
throw new BadRequestException(managedUserValidationResult.Errors);
throw new BadRequestException(claimedUserValidationResult.Errors);
}
await _userService.InitiateEmailChangeAsync(user, model.NewEmail);
@ -355,6 +355,7 @@ public class AccountsController : Controller
throw new BadRequestException(ModelState);
}
[Obsolete("Replaced by the safer rotate-user-account-keys endpoint.")]
[HttpPost("key")]
public async Task PostKey([FromBody] UpdateKeyRequestModel model)
{
@ -436,11 +437,11 @@ public class AccountsController : Controller
var twoFactorEnabled = await _userService.TwoFactorIsEnabledAsync(user);
var hasPremiumFromOrg = await _userService.HasPremiumFromOrganization(user);
var organizationIdsManagingActiveUser = await GetOrganizationIdsManagingUserAsync(user.Id);
var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id);
var response = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails,
providerUserOrganizationDetails, twoFactorEnabled,
hasPremiumFromOrg, organizationIdsManagingActiveUser);
hasPremiumFromOrg, organizationIdsClaimingActiveUser);
return response;
}
@ -450,9 +451,9 @@ public class AccountsController : Controller
var userId = _userService.GetProperUserId(User);
var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(userId.Value,
OrganizationUserStatusType.Confirmed);
var organizationIdsManagingActiveUser = await GetOrganizationIdsManagingUserAsync(userId.Value);
var organizationIdsClaimingUser = await GetOrganizationIdsClaimingUserAsync(userId.Value);
var responseData = organizationUserDetails.Select(o => new ProfileOrganizationResponseModel(o, organizationIdsManagingActiveUser));
var responseData = organizationUserDetails.Select(o => new ProfileOrganizationResponseModel(o, organizationIdsClaimingUser));
return new ListResponseModel<ProfileOrganizationResponseModel>(responseData);
}
@ -470,9 +471,9 @@ public class AccountsController : Controller
var twoFactorEnabled = await _userService.TwoFactorIsEnabledAsync(user);
var hasPremiumFromOrg = await _userService.HasPremiumFromOrganization(user);
var organizationIdsManagingActiveUser = await GetOrganizationIdsManagingUserAsync(user.Id);
var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id);
var response = new ProfileResponseModel(user, null, null, null, twoFactorEnabled, hasPremiumFromOrg, organizationIdsManagingActiveUser);
var response = new ProfileResponseModel(user, null, null, null, twoFactorEnabled, hasPremiumFromOrg, organizationIdsClaimingActiveUser);
return response;
}
@ -489,9 +490,9 @@ public class AccountsController : Controller
var userTwoFactorEnabled = await _userService.TwoFactorIsEnabledAsync(user);
var userHasPremiumFromOrganization = await _userService.HasPremiumFromOrganization(user);
var organizationIdsManagingActiveUser = await GetOrganizationIdsManagingUserAsync(user.Id);
var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id);
var response = new ProfileResponseModel(user, null, null, null, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsManagingActiveUser);
var response = new ProfileResponseModel(user, null, null, null, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingActiveUser);
return response;
}
@ -559,9 +560,9 @@ public class AccountsController : Controller
}
else
{
// If Account Deprovisioning is enabled, we need to check if the user is managed by any organization.
// If Account Deprovisioning is enabled, we need to check if the user is claimed by any organization.
if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)
&& await _userService.IsManagedByAnyOrganizationAsync(user.Id))
&& await _userService.IsClaimedByAnyOrganizationAsync(user.Id))
{
throw new BadRequestException("Cannot delete accounts owned by an organization. Contact your organization administrator for additional details.");
}
@ -762,9 +763,9 @@ public class AccountsController : Controller
await _userService.SaveUserAsync(user);
}
private async Task<IEnumerable<Guid>> GetOrganizationIdsManagingUserAsync(Guid userId)
private async Task<IEnumerable<Guid>> GetOrganizationIdsClaimingUserAsync(Guid userId)
{
var organizationManagingUser = await _userService.GetOrganizationsManagingUserAsync(userId);
return organizationManagingUser.Select(o => o.Id);
var organizationsClaimingUser = await _userService.GetOrganizationsClaimingUserAsync(userId);
return organizationsClaimingUser.Select(o => o.Id);
}
}

View File

@ -0,0 +1,66 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Core.Enums;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.Utilities;
namespace Bit.Api.Auth.Models.Request.Accounts;
public class MasterPasswordUnlockDataModel : IValidatableObject
{
public required KdfType KdfType { get; set; }
public required int KdfIterations { get; set; }
public int? KdfMemory { get; set; }
public int? KdfParallelism { get; set; }
[StrictEmailAddress]
[StringLength(256)]
public required string Email { get; set; }
[StringLength(300)]
public required string MasterKeyAuthenticationHash { get; set; }
[EncryptedString] public required string MasterKeyEncryptedUserKey { get; set; }
[StringLength(50)]
public string? MasterPasswordHint { get; set; }
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
{
if (KdfType == KdfType.PBKDF2_SHA256)
{
if (KdfMemory.HasValue || KdfParallelism.HasValue)
{
yield return new ValidationResult("KdfMemory and KdfParallelism must be null for PBKDF2_SHA256", new[] { nameof(KdfMemory), nameof(KdfParallelism) });
}
}
else if (KdfType == KdfType.Argon2id)
{
if (!KdfMemory.HasValue || !KdfParallelism.HasValue)
{
yield return new ValidationResult("KdfMemory and KdfParallelism must have values for Argon2id", new[] { nameof(KdfMemory), nameof(KdfParallelism) });
}
}
else
{
yield return new ValidationResult("Invalid KdfType", new[] { nameof(KdfType) });
}
}
public MasterPasswordUnlockData ToUnlockData()
{
var data = new MasterPasswordUnlockData
{
KdfType = KdfType,
KdfIterations = KdfIterations,
KdfMemory = KdfMemory,
KdfParallelism = KdfParallelism,
Email = Email,
MasterKeyAuthenticationHash = MasterKeyAuthenticationHash,
MasterKeyEncryptedUserKey = MasterKeyEncryptedUserKey,
MasterPasswordHint = MasterPasswordHint
};
return data;
}
}

View File

@ -0,0 +1,11 @@
using System.ComponentModel.DataAnnotations;
#nullable enable
namespace Bit.Api.Auth.Models.Request;
public class UntrustDevicesRequestModel
{
[Required]
public IEnumerable<Guid> Devices { get; set; } = null!;
}

View File

@ -58,10 +58,10 @@ public class AccountsController(
var userTwoFactorEnabled = await userService.TwoFactorIsEnabledAsync(user);
var userHasPremiumFromOrganization = await userService.HasPremiumFromOrganization(user);
var organizationIdsManagingActiveUser = await GetOrganizationIdsManagingUserAsync(user.Id);
var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id);
var profile = new ProfileResponseModel(user, null, null, null, userTwoFactorEnabled,
userHasPremiumFromOrganization, organizationIdsManagingActiveUser);
userHasPremiumFromOrganization, organizationIdsClaimingActiveUser);
return new PaymentResponseModel
{
UserProfile = profile,
@ -229,9 +229,9 @@ public class AccountsController(
await paymentService.SaveTaxInfoAsync(user, taxInfo);
}
private async Task<IEnumerable<Guid>> GetOrganizationIdsManagingUserAsync(Guid userId)
private async Task<IEnumerable<Guid>> GetOrganizationIdsClaimingUserAsync(Guid userId)
{
var organizationManagingUser = await userService.GetOrganizationsManagingUserAsync(userId);
return organizationManagingUser.Select(o => o.Id);
var organizationsClaimingUser = await userService.GetOrganizationsClaimingUserAsync(userId);
return organizationsClaimingUser.Select(o => o.Id);
}
}

View File

@ -2,6 +2,7 @@
using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses;
using Bit.Core;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Pricing;
@ -18,7 +19,9 @@ namespace Bit.Api.Billing.Controllers;
[Route("organizations/{organizationId:guid}/billing")]
[Authorize("Application")]
public class OrganizationBillingController(
IBusinessUnitConverter businessUnitConverter,
ICurrentContext currentContext,
IFeatureService featureService,
IOrganizationBillingService organizationBillingService,
IOrganizationRepository organizationRepository,
IPaymentService paymentService,
@ -296,4 +299,40 @@ public class OrganizationBillingController(
return TypedResults.Ok();
}
[HttpPost("setup-business-unit")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IResult> SetupBusinessUnitAsync(
[FromRoute] Guid organizationId,
[FromBody] SetupBusinessUnitRequestBody requestBody)
{
var enableOrganizationBusinessUnitConversion =
featureService.IsEnabled(FeatureFlagKeys.PM18770_EnableOrganizationBusinessUnitConversion);
if (!enableOrganizationBusinessUnitConversion)
{
return Error.NotFound();
}
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
return Error.NotFound();
}
if (!await currentContext.OrganizationUser(organizationId))
{
return Error.Unauthorized();
}
var providerId = await businessUnitConverter.FinalizeConversion(
organization,
requestBody.UserId,
requestBody.Token,
requestBody.ProviderKey,
requestBody.OrganizationKey);
return TypedResults.Ok(providerId);
}
}

View File

@ -76,11 +76,35 @@ public class OrganizationSponsorshipsController : Controller
public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model)
{
var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId);
var freeFamiliesSponsorshipPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(sponsoringOrgId,
PolicyType.FreeFamiliesSponsorshipPolicy);
if (freeFamiliesSponsorshipPolicy?.Enabled == true)
{
throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator.");
}
if (!_featureService.IsEnabled(Bit.Core.FeatureFlagKeys.PM17772_AdminInitiatedSponsorships))
{
if (model.SponsoringUserId.HasValue)
{
throw new NotFoundException();
}
if (!string.IsNullOrWhiteSpace(model.Notes))
{
model.Notes = null;
}
}
var targetUser = model.SponsoringUserId ?? _currentContext.UserId!.Value;
var sponsorship = await _createSponsorshipCommand.CreateSponsorshipAsync(
sponsoringOrg,
await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default),
model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName);
await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, targetUser),
model.PlanSponsorshipType,
model.SponsoredEmail,
model.FriendlyName,
model.Notes);
await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name);
}
@ -89,6 +113,14 @@ public class OrganizationSponsorshipsController : Controller
[SelfHosted(NotSelfHostedOnly = true)]
public async Task ResendSponsorshipOffer(Guid sponsoringOrgId)
{
var freeFamiliesSponsorshipPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(sponsoringOrgId,
PolicyType.FreeFamiliesSponsorshipPolicy);
if (freeFamiliesSponsorshipPolicy?.Enabled == true)
{
throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator.");
}
var sponsoringOrgUser = await _organizationUserRepository
.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default);
@ -135,6 +167,14 @@ public class OrganizationSponsorshipsController : Controller
throw new BadRequestException("Can only redeem sponsorship for an organization you own.");
}
var freeFamiliesSponsorshipPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(
model.SponsoredOrganizationId, PolicyType.FreeFamiliesSponsorshipPolicy);
if (freeFamiliesSponsorshipPolicy?.Enabled == true)
{
throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator.");
}
await _setUpSponsorshipCommand.SetUpSponsorshipAsync(
sponsorship,
await _organizationRepository.GetByIdAsync(model.SponsoredOrganizationId));

View File

@ -409,9 +409,9 @@ public class OrganizationsController(
organizationId,
OrganizationUserStatusType.Confirmed);
var organizationIdsManagingActiveUser = (await userService.GetOrganizationsManagingUserAsync(userId))
var organizationIdsClaimingActiveUser = (await userService.GetOrganizationsClaimingUserAsync(userId))
.Select(o => o.Id);
return new ProfileOrganizationResponseModel(organizationUserDetails, organizationIdsManagingActiveUser);
return new ProfileOrganizationResponseModel(organizationUserDetails, organizationIdsClaimingActiveUser);
}
}

View File

@ -0,0 +1,18 @@
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests;
public class SetupBusinessUnitRequestBody
{
[Required]
public Guid UserId { get; set; }
[Required]
public string Token { get; set; }
[Required]
public string ProviderKey { get; set; }
[Required]
public string OrganizationKey { get; set; }
}

View File

@ -1,10 +1,10 @@
using System.ComponentModel.DataAnnotations;
using Bit.Api.Auth.Models.Request;
using Bit.Api.Auth.Models.Request.Accounts;
using Bit.Api.Models.Request;
using Bit.Api.Models.Response;
using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Auth.Models.Api.Response;
using Bit.Core.Auth.UserFeatures.DeviceTrust;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
@ -22,6 +22,7 @@ public class DevicesController : Controller
private readonly IDeviceRepository _deviceRepository;
private readonly IDeviceService _deviceService;
private readonly IUserService _userService;
private readonly IUntrustDevicesCommand _untrustDevicesCommand;
private readonly IUserRepository _userRepository;
private readonly ICurrentContext _currentContext;
private readonly ILogger<DevicesController> _logger;
@ -30,6 +31,7 @@ public class DevicesController : Controller
IDeviceRepository deviceRepository,
IDeviceService deviceService,
IUserService userService,
IUntrustDevicesCommand untrustDevicesCommand,
IUserRepository userRepository,
ICurrentContext currentContext,
ILogger<DevicesController> logger)
@ -37,6 +39,7 @@ public class DevicesController : Controller
_deviceRepository = deviceRepository;
_deviceService = deviceService;
_userService = userService;
_untrustDevicesCommand = untrustDevicesCommand;
_userRepository = userRepository;
_currentContext = currentContext;
_logger = logger;
@ -125,7 +128,7 @@ public class DevicesController : Controller
}
[HttpPost("{identifier}/retrieve-keys")]
public async Task<ProtectedDeviceResponseModel> GetDeviceKeys(string identifier, [FromBody] SecretVerificationRequestModel model)
public async Task<ProtectedDeviceResponseModel> GetDeviceKeys(string identifier)
{
var user = await _userService.GetUserByPrincipalAsync(User);
@ -134,14 +137,7 @@ public class DevicesController : Controller
throw new UnauthorizedAccessException();
}
if (!await _userService.VerifySecretAsync(user, model.Secret))
{
await Task.Delay(2000);
throw new BadRequestException(string.Empty, "User verification failed.");
}
var device = await _deviceRepository.GetByIdentifierAsync(identifier, user.Id);
if (device == null)
{
throw new NotFoundException();
@ -173,6 +169,19 @@ public class DevicesController : Controller
model.OtherDevices ?? Enumerable.Empty<OtherDeviceKeysUpdateRequestModel>());
}
[HttpPost("untrust")]
public async Task PostUntrust([FromBody] UntrustDevicesRequestModel model)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
await _untrustDevicesCommand.UntrustDevices(user, model.Devices);
}
[HttpPut("identifier/{identifier}/token")]
[HttpPost("identifier/{identifier}/token")]
public async Task PutToken(string identifier, [FromBody] DeviceTokenRequestModel model)

View File

@ -3,6 +3,7 @@ using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
@ -20,6 +21,7 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
private readonly ICreateSponsorshipCommand _offerSponsorshipCommand;
private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand;
private readonly ICurrentContext _currentContext;
private readonly IFeatureService _featureService;
public SelfHostedOrganizationSponsorshipsController(
ICreateSponsorshipCommand offerSponsorshipCommand,
@ -27,7 +29,8 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
IOrganizationRepository organizationRepository,
IOrganizationSponsorshipRepository organizationSponsorshipRepository,
IOrganizationUserRepository organizationUserRepository,
ICurrentContext currentContext
ICurrentContext currentContext,
IFeatureService featureService
)
{
_offerSponsorshipCommand = offerSponsorshipCommand;
@ -36,15 +39,29 @@ public class SelfHostedOrganizationSponsorshipsController : Controller
_organizationSponsorshipRepository = organizationSponsorshipRepository;
_organizationUserRepository = organizationUserRepository;
_currentContext = currentContext;
_featureService = featureService;
}
[HttpPost("{sponsoringOrgId}/families-for-enterprise")]
public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model)
{
if (!_featureService.IsEnabled(Bit.Core.FeatureFlagKeys.PM17772_AdminInitiatedSponsorships))
{
if (model.SponsoringUserId.HasValue)
{
throw new NotFoundException();
}
if (!string.IsNullOrWhiteSpace(model.Notes))
{
model.Notes = null;
}
}
await _offerSponsorshipCommand.CreateSponsorshipAsync(
await _organizationRepository.GetByIdAsync(sponsoringOrgId),
await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default),
model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName);
await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, model.SponsoringUserId ?? _currentContext.UserId ?? default),
model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName, model.Notes);
}
[HttpDelete("{sponsoringOrgId}")]

View File

@ -1,10 +1,24 @@
#nullable enable
using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Auth.Models.Request;
using Bit.Api.Auth.Models.Request.WebAuthn;
using Bit.Api.KeyManagement.Models.Requests;
using Bit.Api.KeyManagement.Validators;
using Bit.Api.Tools.Models.Request;
using Bit.Api.Vault.Models.Request;
using Bit.Core;
using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Auth.Models.Data;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.KeyManagement.Commands.Interfaces;
using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
@ -19,18 +33,48 @@ public class AccountsKeyManagementController : Controller
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IRegenerateUserAsymmetricKeysCommand _regenerateUserAsymmetricKeysCommand;
private readonly IUserService _userService;
private readonly IRotateUserAccountKeysCommand _rotateUserAccountKeysCommand;
private readonly IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>> _cipherValidator;
private readonly IRotationValidator<IEnumerable<FolderWithIdRequestModel>, IEnumerable<Folder>> _folderValidator;
private readonly IRotationValidator<IEnumerable<SendWithIdRequestModel>, IReadOnlyList<Send>> _sendValidator;
private readonly IRotationValidator<IEnumerable<EmergencyAccessWithIdRequestModel>, IEnumerable<EmergencyAccess>>
_emergencyAccessValidator;
private readonly IRotationValidator<IEnumerable<ResetPasswordWithOrgIdRequestModel>,
IReadOnlyList<OrganizationUser>>
_organizationUserValidator;
private readonly IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>>
_webauthnKeyValidator;
private readonly IRotationValidator<IEnumerable<OtherDeviceKeysUpdateRequestModel>, IEnumerable<Device>> _deviceValidator;
public AccountsKeyManagementController(IUserService userService,
IFeatureService featureService,
IOrganizationUserRepository organizationUserRepository,
IEmergencyAccessRepository emergencyAccessRepository,
IRegenerateUserAsymmetricKeysCommand regenerateUserAsymmetricKeysCommand)
IRegenerateUserAsymmetricKeysCommand regenerateUserAsymmetricKeysCommand,
IRotateUserAccountKeysCommand rotateUserKeyCommandV2,
IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>> cipherValidator,
IRotationValidator<IEnumerable<FolderWithIdRequestModel>, IEnumerable<Folder>> folderValidator,
IRotationValidator<IEnumerable<SendWithIdRequestModel>, IReadOnlyList<Send>> sendValidator,
IRotationValidator<IEnumerable<EmergencyAccessWithIdRequestModel>, IEnumerable<EmergencyAccess>>
emergencyAccessValidator,
IRotationValidator<IEnumerable<ResetPasswordWithOrgIdRequestModel>, IReadOnlyList<OrganizationUser>>
organizationUserValidator,
IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>> webAuthnKeyValidator,
IRotationValidator<IEnumerable<OtherDeviceKeysUpdateRequestModel>, IEnumerable<Device>> deviceValidator)
{
_userService = userService;
_featureService = featureService;
_regenerateUserAsymmetricKeysCommand = regenerateUserAsymmetricKeysCommand;
_organizationUserRepository = organizationUserRepository;
_emergencyAccessRepository = emergencyAccessRepository;
_rotateUserAccountKeysCommand = rotateUserKeyCommandV2;
_cipherValidator = cipherValidator;
_folderValidator = folderValidator;
_sendValidator = sendValidator;
_emergencyAccessValidator = emergencyAccessValidator;
_organizationUserValidator = organizationUserValidator;
_webauthnKeyValidator = webAuthnKeyValidator;
_deviceValidator = deviceValidator;
}
[HttpPost("regenerate-keys")]
@ -47,4 +91,46 @@ public class AccountsKeyManagementController : Controller
await _regenerateUserAsymmetricKeysCommand.RegenerateKeysAsync(request.ToUserAsymmetricKeys(user.Id),
usersOrganizationAccounts, designatedEmergencyAccess);
}
[HttpPost("rotate-user-account-keys")]
public async Task RotateUserAccountKeysAsync([FromBody] RotateUserAccountKeysAndDataRequestModel model)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var dataModel = new RotateUserAccountKeysData
{
OldMasterKeyAuthenticationHash = model.OldMasterKeyAuthenticationHash,
UserKeyEncryptedAccountPrivateKey = model.AccountKeys.UserKeyEncryptedAccountPrivateKey,
AccountPublicKey = model.AccountKeys.AccountPublicKey,
MasterPasswordUnlockData = model.AccountUnlockData.MasterPasswordUnlockData.ToUnlockData(),
EmergencyAccesses = await _emergencyAccessValidator.ValidateAsync(user, model.AccountUnlockData.EmergencyAccessUnlockData),
OrganizationUsers = await _organizationUserValidator.ValidateAsync(user, model.AccountUnlockData.OrganizationAccountRecoveryUnlockData),
WebAuthnKeys = await _webauthnKeyValidator.ValidateAsync(user, model.AccountUnlockData.PasskeyUnlockData),
DeviceKeys = await _deviceValidator.ValidateAsync(user, model.AccountUnlockData.DeviceKeyUnlockData),
Ciphers = await _cipherValidator.ValidateAsync(user, model.AccountData.Ciphers),
Folders = await _folderValidator.ValidateAsync(user, model.AccountData.Folders),
Sends = await _sendValidator.ValidateAsync(user, model.AccountData.Sends),
};
var result = await _rotateUserAccountKeysCommand.RotateUserAccountKeysAsync(user, dataModel);
if (result.Succeeded)
{
return;
}
foreach (var error in result.Errors)
{
ModelState.AddModelError(string.Empty, error.Description);
}
throw new BadRequestException(ModelState);
}
}

View File

@ -0,0 +1,10 @@
#nullable enable
using Bit.Core.Utilities;
namespace Bit.Api.KeyManagement.Models.Requests;
public class AccountKeysRequestModel
{
[EncryptedString] public required string UserKeyEncryptedAccountPrivateKey { get; set; }
public required string AccountPublicKey { get; set; }
}

View File

@ -0,0 +1,13 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.KeyManagement.Models.Requests;
public class RotateUserAccountKeysAndDataRequestModel
{
[StringLength(300)]
public required string OldMasterKeyAuthenticationHash { get; set; }
public required UnlockDataRequestModel AccountUnlockData { get; set; }
public required AccountKeysRequestModel AccountKeys { get; set; }
public required AccountDataRequestModel AccountData { get; set; }
}

View File

@ -0,0 +1,18 @@
#nullable enable
using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Auth.Models.Request;
using Bit.Api.Auth.Models.Request.Accounts;
using Bit.Api.Auth.Models.Request.WebAuthn;
using Bit.Core.Auth.Models.Api.Request;
namespace Bit.Api.KeyManagement.Models.Requests;
public class UnlockDataRequestModel
{
// All methods to get to the userkey
public required MasterPasswordUnlockDataModel MasterPasswordUnlockData { get; set; }
public required IEnumerable<EmergencyAccessWithIdRequestModel> EmergencyAccessUnlockData { get; set; }
public required IEnumerable<ResetPasswordWithOrgIdRequestModel> OrganizationAccountRecoveryUnlockData { get; set; }
public required IEnumerable<WebAuthnLoginRotateKeyRequestModel> PasskeyUnlockData { get; set; }
public required IEnumerable<OtherDeviceKeysUpdateRequestModel> DeviceKeyUnlockData { get; set; }
}

View File

@ -0,0 +1,12 @@
#nullable enable
using Bit.Api.Tools.Models.Request;
using Bit.Api.Vault.Models.Request;
namespace Bit.Api.KeyManagement.Models.Requests;
public class AccountDataRequestModel
{
public required IEnumerable<CipherWithIdRequestModel> Ciphers { get; set; }
public required IEnumerable<FolderWithIdRequestModel> Folders { get; set; }
public required IEnumerable<SendWithIdRequestModel> Sends { get; set; }
}

View File

@ -0,0 +1,53 @@
using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Auth.Utilities;
using Bit.Core.Entities;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
namespace Bit.Api.KeyManagement.Validators;
/// <summary>
/// Device implementation for <see cref="IRotationValidator{T,R}"/>
/// </summary>
public class DeviceRotationValidator : IRotationValidator<IEnumerable<OtherDeviceKeysUpdateRequestModel>, IEnumerable<Device>>
{
private readonly IDeviceRepository _deviceRepository;
/// <summary>
/// Instantiates a new <see cref="DeviceRotationValidator"/>
/// </summary>
/// <param name="deviceRepository">Retrieves all user <see cref="Device"/>s</param>
public DeviceRotationValidator(IDeviceRepository deviceRepository)
{
_deviceRepository = deviceRepository;
}
public async Task<IEnumerable<Device>> ValidateAsync(User user, IEnumerable<OtherDeviceKeysUpdateRequestModel> devices)
{
var result = new List<Device>();
var existingTrustedDevices = (await _deviceRepository.GetManyByUserIdAsync(user.Id)).Where(d => d.IsTrusted()).ToList();
if (existingTrustedDevices.Count == 0)
{
return result;
}
foreach (var existing in existingTrustedDevices)
{
var device = devices.FirstOrDefault(c => c.DeviceId == existing.Id);
if (device == null)
{
throw new BadRequestException("All existing trusted devices must be included in the rotation.");
}
if (device.EncryptedUserKey == null || device.EncryptedPublicKey == null)
{
throw new BadRequestException("Rotated encryption keys must be provided for all devices that are trusted.");
}
result.Add(device.ToDevice(existing));
}
return result;
}
}

View File

@ -17,20 +17,20 @@ public class WebAuthnLoginKeyRotationValidator : IRotationValidator<IEnumerable<
public async Task<IEnumerable<WebAuthnLoginRotateKeyData>> ValidateAsync(User user, IEnumerable<WebAuthnLoginRotateKeyRequestModel> keysToRotate)
{
// 2024-06: Remove after 3 releases, for backward compatibility
if (keysToRotate == null)
{
return new List<WebAuthnLoginRotateKeyData>();
}
var result = new List<WebAuthnLoginRotateKeyData>();
var existing = await _webAuthnCredentialRepository.GetManyByUserIdAsync(user.Id);
if (existing == null || !existing.Any())
if (existing == null)
{
return result;
}
foreach (var ea in existing)
var validCredentials = existing.Where(credential => credential.SupportsPrf);
if (!validCredentials.Any())
{
return result;
}
foreach (var ea in validCredentials)
{
var keyToRotate = keysToRotate.FirstOrDefault(c => c.Id == ea.Id);
if (keyToRotate == null)

View File

@ -16,4 +16,14 @@ public class OrganizationSponsorshipCreateRequestModel
[StringLength(256)]
public string FriendlyName { get; set; }
/// <summary>
/// (optional) The user to target for the sponsorship.
/// </summary>
/// <remarks>Left empty when creating a sponsorship for the authenticated user.</remarks>
public Guid? SponsoringUserId { get; set; }
[EncryptedString]
[EncryptedStringLength(512)]
public string Notes { get; set; }
}

View File

@ -15,7 +15,7 @@ public class ProfileResponseModel : ResponseModel
IEnumerable<ProviderUserOrganizationDetails> providerUserOrganizationDetails,
bool twoFactorEnabled,
bool premiumFromOrganization,
IEnumerable<Guid> organizationIdsManagingUser) : base("profile")
IEnumerable<Guid> organizationIdsClaimingUser) : base("profile")
{
if (user == null)
{
@ -38,7 +38,7 @@ public class ProfileResponseModel : ResponseModel
AvatarColor = user.AvatarColor;
CreationDate = user.CreationDate;
VerifyDevices = user.VerifyDevices;
Organizations = organizationsUserDetails?.Select(o => new ProfileOrganizationResponseModel(o, organizationIdsManagingUser));
Organizations = organizationsUserDetails?.Select(o => new ProfileOrganizationResponseModel(o, organizationIdsClaimingUser));
Providers = providerUserDetails?.Select(p => new ProfileProviderResponseModel(p));
ProviderOrganizations =
providerUserOrganizationDetails?.Select(po => new ProfileProviderOrganizationResponseModel(po));

View File

@ -22,6 +22,7 @@ public class NotificationResponseModel : ResponseModel
Title = notificationStatusDetails.Title;
Body = notificationStatusDetails.Body;
Date = notificationStatusDetails.RevisionDate;
TaskId = notificationStatusDetails.TaskId;
ReadDate = notificationStatusDetails.ReadDate;
DeletedDate = notificationStatusDetails.DeletedDate;
}
@ -40,6 +41,8 @@ public class NotificationResponseModel : ResponseModel
public DateTime Date { get; set; }
public Guid? TaskId { get; set; }
public DateTime? ReadDate { get; set; }
public DateTime? DeletedDate { get; set; }

View File

@ -31,7 +31,7 @@ using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.Identity.TokenProviders;
using Bit.Core.Tools.ImportFeatures;
using Bit.Core.Tools.ReportFeatures;
using Bit.Core.Auth.Models.Api.Request;
#if !OSS
using Bit.Commercial.Core.SecretsManager;
@ -176,6 +176,9 @@ public class Startup
services
.AddScoped<IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>>,
WebAuthnLoginKeyRotationValidator>();
services
.AddScoped<IRotationValidator<IEnumerable<OtherDeviceKeysUpdateRequestModel>, IEnumerable<Device>>,
DeviceRotationValidator>();
// Services
services.AddBaseServices(globalSettings);

View File

@ -77,10 +77,9 @@ public class ImportCiphersController : Controller
//An User is allowed to import if CanCreate Collections or has AccessToImportExport
var authorized = await CheckOrgImportPermission(collections, orgId);
if (!authorized)
{
throw new NotFoundException();
throw new BadRequestException("Not enough privileges to import into this organization.");
}
var userId = _userService.GetProperUserId(User).Value;
@ -103,21 +102,59 @@ public class ImportCiphersController : Controller
.Select(c => c.Id)
.ToHashSet();
//We need to verify if the user is trying to import into existing collections
var existingCollections = collections.Where(tc => orgCollectionIds.Contains(tc.Id));
//When importing into existing collection, we need to verify if the user has permissions
if (existingCollections.Any() && !(await _authorizationService.AuthorizeAsync(User, existingCollections, BulkCollectionOperations.ImportCiphers)).Succeeded)
// when there are no collections, then we can import
if (collections.Count == 0)
{
return false;
};
//Users allowed to import if they CanCreate Collections
if (!(await _authorizationService.AuthorizeAsync(User, collections, BulkCollectionOperations.Create)).Succeeded)
{
return false;
return true;
}
return true;
// are we trying to import into existing collections?
var existingCollections = collections.Where(tc => orgCollectionIds.Contains(tc.Id));
// are we trying to create new collections?
var hasNewCollections = collections.Any(tc => !orgCollectionIds.Contains(tc.Id));
// suppose we have both new and existing collections
if (hasNewCollections && existingCollections.Any())
{
// since we are creating new collection, user must have import/manage and create collection permission
if ((await _authorizationService.AuthorizeAsync(User, collections, BulkCollectionOperations.Create)).Succeeded
&& (await _authorizationService.AuthorizeAsync(User, existingCollections, BulkCollectionOperations.ImportCiphers)).Succeeded)
{
// can import collections and create new ones
return true;
}
else
{
// user does not have permission to import
return false;
}
}
// suppose we have new collections and none of our collections exist
if (hasNewCollections && !existingCollections.Any())
{
// user is trying to create new collections
// we need to check if the user has permission to create collections
if ((await _authorizationService.AuthorizeAsync(User, collections, BulkCollectionOperations.Create)).Succeeded)
{
return true;
}
else
{
// user does not have permission to create new collections
return false;
}
}
// in many import formats, we don't create collections, we just import ciphers into an existing collection
// When importing, we need to verify if the user has ImportCiphers permission
if (existingCollections.Any() && (await _authorizationService.AuthorizeAsync(User, existingCollections, BulkCollectionOperations.ImportCiphers)).Succeeded)
{
return true;
};
return false;
}
}

View File

@ -1,4 +1,5 @@
using Bit.Api.Tools.Authorization;
using Bit.Api.AdminConsole.Authorization;
using Bit.Api.Tools.Authorization;
using Bit.Api.Vault.AuthorizationHandlers.Collections;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Authorization;
using Bit.Core.IdentityServer;
@ -109,6 +110,8 @@ public static class ServiceCollectionExtensions
services.AddScoped<IAuthorizationHandler, VaultExportAuthorizationHandler>();
services.AddScoped<IAuthorizationHandler, SecurityTaskAuthorizationHandler>();
services.AddScoped<IAuthorizationHandler, SecurityTaskOrganizationAuthorizationHandler>();
services.AddScoped<IAuthorizationHandler, OrganizationRequirementHandler>();
}
public static void AddPhishingDomainServices(this IServiceCollection services, GlobalSettings globalSettings)

View File

@ -16,6 +16,7 @@ using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
using Bit.Core.Vault.Authorization.Permissions;
using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Models.Data;
using Bit.Core.Vault.Queries;
@ -127,9 +128,9 @@ public class CiphersController : Controller
public async Task<ListResponseModel<CipherDetailsResponseModel>> Get()
{
var user = await _userService.GetUserByPrincipalAsync(User);
var hasOrgs = _currentContext.Organizations?.Any() ?? false;
var hasOrgs = _currentContext.Organizations.Count != 0;
// TODO: Use hasOrgs proper for cipher listing here?
var ciphers = await _cipherRepository.GetManyByUserIdAsync(user.Id, withOrganizations: true || hasOrgs);
var ciphers = await _cipherRepository.GetManyByUserIdAsync(user.Id, withOrganizations: true);
Dictionary<Guid, IGrouping<Guid, CollectionCipher>> collectionCiphersGroupDict = null;
if (hasOrgs)
{
@ -345,6 +346,77 @@ public class CiphersController : Controller
return await CanEditCiphersAsync(organizationId, cipherIds);
}
private async Task<bool> CanDeleteOrRestoreCipherAsAdminAsync(Guid organizationId, IEnumerable<Guid> cipherIds)
{
if (!_featureService.IsEnabled(FeatureFlagKeys.LimitItemDeletion))
{
return await CanEditCipherAsAdminAsync(organizationId, cipherIds);
}
var org = _currentContext.GetOrganization(organizationId);
// If we're not an "admin", we don't need to check the ciphers
if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true }))
{
// Are we a provider user? If so, we need to be sure we're not restricted
// Once the feature flag is removed, this check can be combined with the above
if (await _currentContext.ProviderUserForOrgAsync(organizationId))
{
// Provider is restricted from editing ciphers, so we're not an "admin"
if (_featureService.IsEnabled(FeatureFlagKeys.RestrictProviderAccess))
{
return false;
}
// Provider is unrestricted, so we're an "admin", don't return early
}
else
{
// Not a provider or admin
return false;
}
}
// If the user can edit all ciphers for the organization, just check they all belong to the org
if (await CanEditAllCiphersAsync(organizationId))
{
// TODO: This can likely be optimized to only query the requested ciphers and then checking they belong to the org
var orgCiphers = (await _cipherRepository.GetManyByOrganizationIdAsync(organizationId)).ToDictionary(c => c.Id);
// Ensure all requested ciphers are in orgCiphers
return cipherIds.All(c => orgCiphers.ContainsKey(c));
}
// The user cannot access any ciphers for the organization, we're done
if (!await CanAccessOrganizationCiphersAsync(organizationId))
{
return false;
}
var user = await _userService.GetUserByPrincipalAsync(User);
// Select all deletable ciphers for this user belonging to the organization
var deletableOrgCipherList = (await _cipherRepository.GetManyByUserIdAsync(user.Id, true))
.Where(c => c.OrganizationId == organizationId && c.UserId == null).ToList();
// Special case for unassigned ciphers
if (await CanAccessUnassignedCiphersAsync(organizationId))
{
var unassignedCiphers =
(await _cipherRepository.GetManyUnassignedOrganizationDetailsByOrganizationIdAsync(
organizationId));
// Users that can access unassigned ciphers can also delete them
deletableOrgCipherList.AddRange(unassignedCiphers.Select(c => new CipherDetails(c) { Manage = true }));
}
var organizationAbility = await _applicationCacheService.GetOrganizationAbilityAsync(organizationId);
var deletableOrgCiphers = deletableOrgCipherList
.Where(c => NormalCipherPermissions.CanDelete(user, c, organizationAbility))
.ToDictionary(c => c.Id);
return cipherIds.All(c => deletableOrgCiphers.ContainsKey(c));
}
/// <summary>
/// TODO: Move this to its own authorization handler or equivalent service - AC-2062
/// </summary>
@ -763,12 +835,12 @@ public class CiphersController : Controller
[HttpDelete("{id}/admin")]
[HttpPost("{id}/delete-admin")]
public async Task DeleteAdmin(string id)
public async Task DeleteAdmin(Guid id)
{
var userId = _userService.GetProperUserId(User).Value;
var cipher = await _cipherRepository.GetByIdAsync(new Guid(id));
var cipher = await GetByIdAsync(id, userId);
if (cipher == null || !cipher.OrganizationId.HasValue ||
!await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))
!await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))
{
throw new NotFoundException();
}
@ -808,7 +880,7 @@ public class CiphersController : Controller
var cipherIds = model.Ids.Select(i => new Guid(i)).ToList();
if (string.IsNullOrWhiteSpace(model.OrganizationId) ||
!await CanEditCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds))
!await CanDeleteOrRestoreCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds))
{
throw new NotFoundException();
}
@ -830,12 +902,12 @@ public class CiphersController : Controller
}
[HttpPut("{id}/delete-admin")]
public async Task PutDeleteAdmin(string id)
public async Task PutDeleteAdmin(Guid id)
{
var userId = _userService.GetProperUserId(User).Value;
var cipher = await _cipherRepository.GetByIdAsync(new Guid(id));
var cipher = await GetByIdAsync(id, userId);
if (cipher == null || !cipher.OrganizationId.HasValue ||
!await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))
!await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))
{
throw new NotFoundException();
}
@ -871,7 +943,7 @@ public class CiphersController : Controller
var cipherIds = model.Ids.Select(i => new Guid(i)).ToList();
if (string.IsNullOrWhiteSpace(model.OrganizationId) ||
!await CanEditCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds))
!await CanDeleteOrRestoreCipherAsAdminAsync(new Guid(model.OrganizationId), cipherIds))
{
throw new NotFoundException();
}
@ -899,12 +971,12 @@ public class CiphersController : Controller
}
[HttpPut("{id}/restore-admin")]
public async Task<CipherMiniResponseModel> PutRestoreAdmin(string id)
public async Task<CipherMiniResponseModel> PutRestoreAdmin(Guid id)
{
var userId = _userService.GetProperUserId(User).Value;
var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id));
var cipher = await GetByIdAsync(id, userId);
if (cipher == null || !cipher.OrganizationId.HasValue ||
!await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))
!await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))
{
throw new NotFoundException();
}
@ -944,7 +1016,7 @@ public class CiphersController : Controller
var cipherIdsToRestore = new HashSet<Guid>(model.Ids.Select(i => new Guid(i)));
if (model.OrganizationId == default || !await CanEditCipherAsAdminAsync(model.OrganizationId, cipherIdsToRestore))
if (model.OrganizationId == default || !await CanDeleteOrRestoreCipherAsAdminAsync(model.OrganizationId, cipherIdsToRestore))
{
throw new NotFoundException();
}
@ -1019,9 +1091,9 @@ public class CiphersController : Controller
throw new BadRequestException(ModelState);
}
// If Account Deprovisioning is enabled, we need to check if the user is managed by any organization.
// If Account Deprovisioning is enabled, we need to check if the user is claimed by any organization.
if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)
&& await _userService.IsManagedByAnyOrganizationAsync(user.Id))
&& await _userService.IsClaimedByAnyOrganizationAsync(user.Id))
{
throw new BadRequestException("Cannot purge accounts owned by an organization. Contact your organization administrator for additional details.");
}

View File

@ -5,6 +5,7 @@ using Bit.Core;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Bit.Core.Vault.Commands.Interfaces;
using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Enums;
using Bit.Core.Vault.Queries;
using Microsoft.AspNetCore.Authorization;
@ -89,11 +90,28 @@ public class SecurityTaskController : Controller
public async Task<ListResponseModel<SecurityTasksResponseModel>> BulkCreateTasks(Guid orgId,
[FromBody] BulkCreateSecurityTasksRequestModel model)
{
var securityTasks = await _createManyTasksCommand.CreateAsync(orgId, model.Tasks);
// Retrieve existing pending security tasks for the organization
var pendingSecurityTasks = await _getTasksForOrganizationQuery.GetTasksAsync(orgId, SecurityTaskStatus.Pending);
await _createManyTaskNotificationsCommand.CreateAsync(orgId, securityTasks);
// Get the security tasks that are already associated with a cipher within the submitted model
var existingTasks = pendingSecurityTasks.Where(x => model.Tasks.Any(y => y.CipherId == x.CipherId)).ToList();
var response = securityTasks.Select(x => new SecurityTasksResponseModel(x)).ToList();
// Get tasks that need to be created
var tasksToCreateFromModel = model.Tasks.Where(x => !existingTasks.Any(y => y.CipherId == x.CipherId)).ToList();
ICollection<SecurityTask> newSecurityTasks = new List<SecurityTask>();
if (tasksToCreateFromModel.Count != 0)
{
newSecurityTasks = await _createManyTasksCommand.CreateAsync(orgId, tasksToCreateFromModel);
}
// Combine existing tasks and newly created tasks
var allTasks = existingTasks.Concat(newSecurityTasks);
await _createManyTaskNotificationsCommand.CreateAsync(orgId, allTasks);
var response = allTasks.Select(x => new SecurityTasksResponseModel(x)).ToList();
return new ListResponseModel<SecurityTasksResponseModel>(response);
}
}

View File

@ -104,13 +104,13 @@ public class SyncController : Controller
var userTwoFactorEnabled = await _userService.TwoFactorIsEnabledAsync(user);
var userHasPremiumFromOrganization = await _userService.HasPremiumFromOrganization(user);
var organizationManagingActiveUser = await _userService.GetOrganizationsManagingUserAsync(user.Id);
var organizationIdsManagingActiveUser = organizationManagingActiveUser.Select(o => o.Id);
var organizationClaimingActiveUser = await _userService.GetOrganizationsClaimingUserAsync(user.Id);
var organizationIdsClaimingActiveUser = organizationClaimingActiveUser.Select(o => o.Id);
var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var response = new SyncResponseModel(_globalSettings, user, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationAbilities,
organizationIdsManagingActiveUser, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails,
organizationIdsClaimingActiveUser, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails,
folders, collections, ciphers, collectionCiphersGroupDict, excludeDomains, policies, sends);
return response;
}

View File

@ -23,7 +23,7 @@ public class SyncResponseModel : ResponseModel
bool userTwoFactorEnabled,
bool userHasPremiumFromOrganization,
IDictionary<Guid, OrganizationAbility> organizationAbilities,
IEnumerable<Guid> organizationIdsManagingUser,
IEnumerable<Guid> organizationIdsClaimingingUser,
IEnumerable<OrganizationUserOrganizationDetails> organizationUserDetails,
IEnumerable<ProviderUserProviderDetails> providerUserDetails,
IEnumerable<ProviderUserOrganizationDetails> providerUserOrganizationDetails,
@ -37,7 +37,7 @@ public class SyncResponseModel : ResponseModel
: base("sync")
{
Profile = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails,
providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsManagingUser);
providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingingUser);
Folders = folders.Select(f => new FolderResponseModel(f));
Ciphers = ciphers.Select(cipher =>
new CipherDetailsResponseModel(

View File

@ -2,9 +2,6 @@
<PropertyGroup>
<UserSecretsId>bitwarden-Billing</UserSecretsId>
<MvcRazorCompileOnPublish>false</MvcRazorCompileOnPublish>
<!-- Temp exclusions until warnings are fixed -->
<WarningsNotAsErrors>$(WarningsNotAsErrors);CS9113</WarningsNotAsErrors>
</PropertyGroup>
<PropertyGroup Condition=" '$(RunConfiguration)' == 'Billing' " />

View File

@ -1,6 +1,5 @@
using Bit.Billing.Constants;
using Bit.Billing.Jobs;
using Bit.Core;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.Billing.Pricing;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
@ -24,7 +23,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
private readonly IPushNotificationService _pushNotificationService;
private readonly IOrganizationRepository _organizationRepository;
private readonly ISchedulerFactory _schedulerFactory;
private readonly IFeatureService _featureService;
private readonly IOrganizationEnableCommand _organizationEnableCommand;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IPricingClient _pricingClient;
@ -39,7 +37,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
IPushNotificationService pushNotificationService,
IOrganizationRepository organizationRepository,
ISchedulerFactory schedulerFactory,
IFeatureService featureService,
IOrganizationEnableCommand organizationEnableCommand,
IOrganizationDisableCommand organizationDisableCommand,
IPricingClient pricingClient)
@ -53,7 +50,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
_pushNotificationService = pushNotificationService;
_organizationRepository = organizationRepository;
_schedulerFactory = schedulerFactory;
_featureService = featureService;
_organizationEnableCommand = organizationEnableCommand;
_organizationDisableCommand = organizationDisableCommand;
_pricingClient = pricingClient;
@ -227,12 +223,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
private async Task ScheduleCancellationJobAsync(string subscriptionId, Guid organizationId)
{
var isResellerManagedOrgAlertEnabled = _featureService.IsEnabled(FeatureFlagKeys.ResellerManagedOrgAlert);
if (!isResellerManagedOrgAlertEnabled)
{
return;
}
var scheduler = await _schedulerFactory.GetScheduler();
var job = JobBuilder.Create<SubscriptionCancellationJob>()

View File

@ -1,8 +1,11 @@
using Bit.Core.AdminConsole.Repositories;
using Bit.Core;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Services;
@ -12,6 +15,7 @@ using Event = Stripe.Event;
namespace Bit.Billing.Services.Implementations;
public class UpcomingInvoiceHandler(
IFeatureService featureService,
ILogger<StripeEventProcessor> logger,
IMailService mailService,
IOrganizationRepository organizationRepository,
@ -21,7 +25,8 @@ public class UpcomingInvoiceHandler(
IStripeEventService stripeEventService,
IStripeEventUtilityService stripeEventUtilityService,
IUserRepository userRepository,
IValidateSponsorshipCommand validateSponsorshipCommand)
IValidateSponsorshipCommand validateSponsorshipCommand,
IAutomaticTaxFactory automaticTaxFactory)
: IUpcomingInvoiceHandler
{
public async Task HandleAsync(Event parsedEvent)
@ -136,6 +141,21 @@ public class UpcomingInvoiceHandler(
private async Task TryEnableAutomaticTaxAsync(Subscription subscription)
{
if (featureService.IsEnabled(FeatureFlagKeys.PM19147_AutomaticTaxImprovements))
{
var automaticTaxParameters = new AutomaticTaxFactoryParameters(subscription.Items.Select(x => x.Price.Id));
var automaticTaxStrategy = await automaticTaxFactory.CreateAsync(automaticTaxParameters);
var updateOptions = automaticTaxStrategy.GetUpdateOptions(subscription);
if (updateOptions == null)
{
return;
}
await stripeFacade.UpdateSubscription(subscription.Id, updateOptions);
return;
}
if (subscription.AutomaticTax.Enabled ||
!subscription.Customer.HasBillingLocation() ||
await IsNonTaxableNonUSBusinessUseSubscription(subscription))

Some files were not shown because too many files have changed in this diff Show More