1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-05 05:00:19 -05:00

Merge branch 'main' into ac/pm-15621/refactor-delete-command

This commit is contained in:
Jimmy Vo 2025-03-24 10:26:55 -04:00
commit 379cddcef9
No known key found for this signature in database
GPG Key ID: 7CB834D6F4FFCA11
128 changed files with 12747 additions and 2477 deletions

View File

@ -1,5 +1,3 @@
version: '3'
services: services:
bitwarden_server: bitwarden_server:
image: mcr.microsoft.com/devcontainers/dotnet:8.0 image: mcr.microsoft.com/devcontainers/dotnet:8.0
@ -13,7 +11,8 @@ services:
platform: linux/amd64 platform: linux/amd64
restart: unless-stopped restart: unless-stopped
env_file: env_file:
../../dev/.env - path: ../../dev/.env
required: false
environment: environment:
ACCEPT_EULA: "Y" ACCEPT_EULA: "Y"
MSSQL_PID: Developer MSSQL_PID: Developer

View File

@ -51,4 +51,10 @@ Proceed? [y/N] " response
} }
# main # main
one_time_setup if [[ -z "${CODESPACES}" ]]; then
one_time_setup
else
# Ignore interactive elements when running in codespaces since they are not supported there
# TODO Write codespaces specific instructions and link here
echo "Running in codespaces, follow instructions here: https://contributing.bitwarden.com/getting-started/server/guide/ to continue the setup"
fi

View File

@ -1,5 +1,3 @@
version: '3'
services: services:
bitwarden_storage: bitwarden_storage:
image: mcr.microsoft.com/azure-storage/azurite:latest image: mcr.microsoft.com/azure-storage/azurite:latest

View File

@ -89,4 +89,10 @@ install_stripe_cli() {
} }
# main # main
one_time_setup if [[ -z "${CODESPACES}" ]]; then
one_time_setup
else
# Ignore interactive elements when running in codespaces since they are not supported there
# TODO Write codespaces specific instructions and link here
echo "Running in codespaces, follow instructions here: https://contributing.bitwarden.com/getting-started/server/guide/ to continue the setup"
fi

View File

@ -32,28 +32,9 @@ on:
- "src/**/Entities/**/*.cs" # Database entity definitions - "src/**/Entities/**/*.cs" # Database entity definitions
jobs: jobs:
check-test-secrets:
name: Check for test secrets
runs-on: ubuntu-22.04
outputs:
available: ${{ steps.check-test-secrets.outputs.available }}
permissions:
contents: read
steps:
- name: Check
id: check-test-secrets
run: |
if [ "${{ secrets.CODECOV_TOKEN }}" != '' ]; then
echo "available=true" >> $GITHUB_OUTPUT;
else
echo "available=false" >> $GITHUB_OUTPUT;
fi
test: test:
name: Run tests name: Run tests
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
needs: check-test-secrets
steps: steps:
- name: Check out repo - name: Check out repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -166,8 +147,8 @@ jobs:
run: 'docker logs $(docker ps --quiet --filter "name=mssql")' run: 'docker logs $(docker ps --quiet --filter "name=mssql")'
- name: Report test results - name: Report test results
uses: dorny/test-reporter@31a54ee7ebcacc03a09ea97a7e5465a47b84aea5 # v1.9.1 uses: dorny/test-reporter@6e6a65b7a0bd2c9197df7d0ae36ac5cee784230c # v2.0.0
if: ${{ needs.check-test-secrets.outputs.available == 'true' && !cancelled() }} if: ${{ github.event.pull_request.head.repo.full_name == github.repository && !cancelled() }}
with: with:
name: Test Results name: Test Results
path: "**/*-test-results.trx" path: "**/*-test-results.trx"

View File

@ -13,29 +13,10 @@ env:
_AZ_REGISTRY: "bitwardenprod.azurecr.io" _AZ_REGISTRY: "bitwardenprod.azurecr.io"
jobs: jobs:
check-test-secrets:
name: Check for test secrets
runs-on: ubuntu-22.04
outputs:
available: ${{ steps.check-test-secrets.outputs.available }}
permissions:
contents: read
steps:
- name: Check
id: check-test-secrets
run: |
if [ "${{ secrets.CODECOV_TOKEN }}" != '' ]; then
echo "available=true" >> $GITHUB_OUTPUT;
else
echo "available=false" >> $GITHUB_OUTPUT;
fi
testing: testing:
name: Run tests name: Run tests
if: ${{ startsWith(github.head_ref, 'version_bump_') == false }} if: ${{ startsWith(github.head_ref, 'version_bump_') == false }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
needs: check-test-secrets
permissions: permissions:
checks: write checks: write
contents: read contents: read
@ -68,8 +49,8 @@ jobs:
run: dotnet test ./bitwarden_license/test --configuration Debug --logger "trx;LogFileName=bw-test-results.trx" /p:CoverletOutputFormatter="cobertura" --collect:"XPlat Code Coverage" run: dotnet test ./bitwarden_license/test --configuration Debug --logger "trx;LogFileName=bw-test-results.trx" /p:CoverletOutputFormatter="cobertura" --collect:"XPlat Code Coverage"
- name: Report test results - name: Report test results
uses: dorny/test-reporter@31a54ee7ebcacc03a09ea97a7e5465a47b84aea5 # v1.9.1 uses: dorny/test-reporter@6e6a65b7a0bd2c9197df7d0ae36ac5cee784230c # v2.0.0
if: ${{ needs.check-test-secrets.outputs.available == 'true' && !cancelled() }} if: ${{ github.event.pull_request.head.repo.full_name == github.repository && !cancelled() }}
with: with:
name: Test Results name: Test Results
path: "**/*-test-results.trx" path: "**/*-test-results.trx"

View File

@ -3,7 +3,7 @@
<PropertyGroup> <PropertyGroup>
<TargetFramework>net8.0</TargetFramework> <TargetFramework>net8.0</TargetFramework>
<Version>2025.2.3</Version> <Version>2025.3.6</Version>
<RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace> <RootNamespace>Bit.$(MSBuildProjectName)</RootNamespace>
<ImplicitUsings>enable</ImplicitUsings> <ImplicitUsings>enable</ImplicitUsings>

View File

@ -628,6 +628,19 @@ public class ProviderBillingService(
} }
} }
public async Task UpdatePaymentMethod(
Provider provider,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation)
{
await Task.WhenAll(
subscriberService.UpdatePaymentSource(provider, tokenizedPaymentSource),
subscriberService.UpdateTaxInformation(provider, taxInformation));
await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId,
new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically });
}
public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command) public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command)
{ {
if (command.Configuration.Any(x => x.SeatsMinimum < 0)) if (command.Configuration.Any(x => x.SeatsMinimum < 0))

View File

@ -11,6 +11,7 @@ using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Entities; using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Billing.Services.Contracts; using Bit.Core.Billing.Services.Contracts;
@ -42,6 +43,7 @@ public class ProvidersController : Controller
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
private readonly IProviderPlanRepository _providerPlanRepository; private readonly IProviderPlanRepository _providerPlanRepository;
private readonly IProviderBillingService _providerBillingService; private readonly IProviderBillingService _providerBillingService;
private readonly IPricingClient _pricingClient;
private readonly string _stripeUrl; private readonly string _stripeUrl;
private readonly string _braintreeMerchantUrl; private readonly string _braintreeMerchantUrl;
private readonly string _braintreeMerchantId; private readonly string _braintreeMerchantId;
@ -60,7 +62,8 @@ public class ProvidersController : Controller
IFeatureService featureService, IFeatureService featureService,
IProviderPlanRepository providerPlanRepository, IProviderPlanRepository providerPlanRepository,
IProviderBillingService providerBillingService, IProviderBillingService providerBillingService,
IWebHostEnvironment webHostEnvironment) IWebHostEnvironment webHostEnvironment,
IPricingClient pricingClient)
{ {
_organizationRepository = organizationRepository; _organizationRepository = organizationRepository;
_organizationService = organizationService; _organizationService = organizationService;
@ -75,6 +78,7 @@ public class ProvidersController : Controller
_featureService = featureService; _featureService = featureService;
_providerPlanRepository = providerPlanRepository; _providerPlanRepository = providerPlanRepository;
_providerBillingService = providerBillingService; _providerBillingService = providerBillingService;
_pricingClient = pricingClient;
_stripeUrl = webHostEnvironment.GetStripeUrl(); _stripeUrl = webHostEnvironment.GetStripeUrl();
_braintreeMerchantUrl = webHostEnvironment.GetBraintreeMerchantUrl(); _braintreeMerchantUrl = webHostEnvironment.GetBraintreeMerchantUrl();
_braintreeMerchantId = globalSettings.Braintree.MerchantId; _braintreeMerchantId = globalSettings.Braintree.MerchantId;
@ -415,7 +419,9 @@ public class ProvidersController : Controller
return RedirectToAction("Index"); return RedirectToAction("Index");
} }
return View(new OrganizationEditModel(provider)); var plans = await _pricingClient.ListPlans();
return View(new OrganizationEditModel(provider, plans));
} }
[HttpPost] [HttpPost]

View File

@ -22,13 +22,14 @@ public class OrganizationEditModel : OrganizationViewModel
public OrganizationEditModel() { } public OrganizationEditModel() { }
public OrganizationEditModel(Provider provider) public OrganizationEditModel(Provider provider, List<Plan> plans)
{ {
Provider = provider; Provider = provider;
BillingEmail = provider.Type == ProviderType.Reseller ? provider.BillingEmail : string.Empty; BillingEmail = provider.Type == ProviderType.Reseller ? provider.BillingEmail : string.Empty;
PlanType = Core.Billing.Enums.PlanType.TeamsMonthly; PlanType = Core.Billing.Enums.PlanType.TeamsMonthly;
Plan = Core.Billing.Enums.PlanType.TeamsMonthly.GetDisplayAttribute()?.GetName(); Plan = Core.Billing.Enums.PlanType.TeamsMonthly.GetDisplayAttribute()?.GetName();
LicenseKey = RandomLicenseKey; LicenseKey = RandomLicenseKey;
_plans = plans;
} }
public OrganizationEditModel( public OrganizationEditModel(

View File

@ -9,6 +9,8 @@ using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Authorization; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Authorization;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.OrganizationFeatures.Shared.Authorization; using Bit.Core.AdminConsole.OrganizationFeatures.Shared.Authorization;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Enums; using Bit.Core.Auth.Enums;
@ -56,6 +58,7 @@ public class OrganizationUsersController : Controller
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
private readonly IDeleteManagedOrganizationUserAccountCommand _deleteManagedOrganizationUserAccountCommand; private readonly IDeleteManagedOrganizationUserAccountCommand _deleteManagedOrganizationUserAccountCommand;
private readonly IGetOrganizationUsersManagementStatusQuery _getOrganizationUsersManagementStatusQuery; private readonly IGetOrganizationUsersManagementStatusQuery _getOrganizationUsersManagementStatusQuery;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
@ -80,6 +83,7 @@ public class OrganizationUsersController : Controller
IRemoveOrganizationUserCommand removeOrganizationUserCommand, IRemoveOrganizationUserCommand removeOrganizationUserCommand,
IDeleteManagedOrganizationUserAccountCommand deleteManagedOrganizationUserAccountCommand, IDeleteManagedOrganizationUserAccountCommand deleteManagedOrganizationUserAccountCommand,
IGetOrganizationUsersManagementStatusQuery getOrganizationUsersManagementStatusQuery, IGetOrganizationUsersManagementStatusQuery getOrganizationUsersManagementStatusQuery,
IPolicyRequirementQuery policyRequirementQuery,
IFeatureService featureService, IFeatureService featureService,
IPricingClient pricingClient) IPricingClient pricingClient)
{ {
@ -103,6 +107,7 @@ public class OrganizationUsersController : Controller
_removeOrganizationUserCommand = removeOrganizationUserCommand; _removeOrganizationUserCommand = removeOrganizationUserCommand;
_deleteManagedOrganizationUserAccountCommand = deleteManagedOrganizationUserAccountCommand; _deleteManagedOrganizationUserAccountCommand = deleteManagedOrganizationUserAccountCommand;
_getOrganizationUsersManagementStatusQuery = getOrganizationUsersManagementStatusQuery; _getOrganizationUsersManagementStatusQuery = getOrganizationUsersManagementStatusQuery;
_policyRequirementQuery = policyRequirementQuery;
_featureService = featureService; _featureService = featureService;
_pricingClient = pricingClient; _pricingClient = pricingClient;
} }
@ -316,11 +321,13 @@ public class OrganizationUsersController : Controller
throw new UnauthorizedAccessException(); throw new UnauthorizedAccessException();
} }
var useMasterPasswordPolicy = await ShouldHandleResetPasswordAsync(orgId); var useMasterPasswordPolicy = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements)
? (await _policyRequirementQuery.GetAsync<ResetPasswordPolicyRequirement>(user.Id)).AutoEnrollEnabled(orgId)
: await ShouldHandleResetPasswordAsync(orgId);
if (useMasterPasswordPolicy && string.IsNullOrWhiteSpace(model.ResetPasswordKey)) if (useMasterPasswordPolicy && string.IsNullOrWhiteSpace(model.ResetPasswordKey))
{ {
throw new BadRequestException(string.Empty, "Master Password reset is required, but not provided."); throw new BadRequestException("Master Password reset is required, but not provided.");
} }
await _acceptOrgUserCommand.AcceptOrgUserByEmailTokenAsync(organizationUserId, user, model.Token, _userService); await _acceptOrgUserCommand.AcceptOrgUserByEmailTokenAsync(organizationUserId, user, model.Token, _userService);

View File

@ -16,6 +16,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationApiKeys.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Enums; using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Repositories; using Bit.Core.Auth.Repositories;
@ -61,6 +63,7 @@ public class OrganizationsController : Controller
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
private readonly ICloudOrganizationSignUpCommand _cloudOrganizationSignUpCommand; private readonly ICloudOrganizationSignUpCommand _cloudOrganizationSignUpCommand;
private readonly IOrganizationDeleteCommand _organizationDeleteCommand; private readonly IOrganizationDeleteCommand _organizationDeleteCommand;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
public OrganizationsController( public OrganizationsController(
@ -84,6 +87,7 @@ public class OrganizationsController : Controller
IRemoveOrganizationUserCommand removeOrganizationUserCommand, IRemoveOrganizationUserCommand removeOrganizationUserCommand,
ICloudOrganizationSignUpCommand cloudOrganizationSignUpCommand, ICloudOrganizationSignUpCommand cloudOrganizationSignUpCommand,
IOrganizationDeleteCommand organizationDeleteCommand, IOrganizationDeleteCommand organizationDeleteCommand,
IPolicyRequirementQuery policyRequirementQuery,
IPricingClient pricingClient) IPricingClient pricingClient)
{ {
_organizationRepository = organizationRepository; _organizationRepository = organizationRepository;
@ -106,6 +110,7 @@ public class OrganizationsController : Controller
_removeOrganizationUserCommand = removeOrganizationUserCommand; _removeOrganizationUserCommand = removeOrganizationUserCommand;
_cloudOrganizationSignUpCommand = cloudOrganizationSignUpCommand; _cloudOrganizationSignUpCommand = cloudOrganizationSignUpCommand;
_organizationDeleteCommand = organizationDeleteCommand; _organizationDeleteCommand = organizationDeleteCommand;
_policyRequirementQuery = policyRequirementQuery;
_pricingClient = pricingClient; _pricingClient = pricingClient;
} }
@ -163,8 +168,13 @@ public class OrganizationsController : Controller
throw new NotFoundException(); throw new NotFoundException();
} }
var resetPasswordPolicy = if (_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements))
await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); {
var resetPasswordPolicyRequirement = await _policyRequirementQuery.GetAsync<ResetPasswordPolicyRequirement>(user.Id);
return new OrganizationAutoEnrollStatusResponseModel(organization.Id, resetPasswordPolicyRequirement.AutoEnrollEnabled(organization.Id));
}
var resetPasswordPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword);
if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled || resetPasswordPolicy.Data == null) if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled || resetPasswordPolicy.Data == null)
{ {
return new OrganizationAutoEnrollStatusResponseModel(organization.Id, false); return new OrganizationAutoEnrollStatusResponseModel(organization.Id, false);
@ -172,6 +182,7 @@ public class OrganizationsController : Controller
var data = JsonSerializer.Deserialize<ResetPasswordDataModel>(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); var data = JsonSerializer.Deserialize<ResetPasswordDataModel>(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase);
return new OrganizationAutoEnrollStatusResponseModel(organization.Id, data?.AutoEnrollEnabled ?? false); return new OrganizationAutoEnrollStatusResponseModel(organization.Id, data?.AutoEnrollEnabled ?? false);
} }
[HttpPost("")] [HttpPost("")]

View File

@ -23,6 +23,7 @@ public class PendingOrganizationAuthRequestResponseModel : ResponseModel
RequestDeviceType = authRequest.RequestDeviceType.GetType().GetMember(authRequest.RequestDeviceType.ToString()) RequestDeviceType = authRequest.RequestDeviceType.GetType().GetMember(authRequest.RequestDeviceType.ToString())
.FirstOrDefault()?.GetCustomAttribute<DisplayAttribute>()?.GetName(); .FirstOrDefault()?.GetCustomAttribute<DisplayAttribute>()?.GetName();
RequestIpAddress = authRequest.RequestIpAddress; RequestIpAddress = authRequest.RequestIpAddress;
RequestCountryName = authRequest.RequestCountryName;
CreationDate = authRequest.CreationDate; CreationDate = authRequest.CreationDate;
} }
@ -34,5 +35,6 @@ public class PendingOrganizationAuthRequestResponseModel : ResponseModel
public string RequestDeviceIdentifier { get; set; } public string RequestDeviceIdentifier { get; set; }
public string RequestDeviceType { get; set; } public string RequestDeviceType { get; set; }
public string RequestIpAddress { get; set; } public string RequestIpAddress { get; set; }
public string RequestCountryName { get; set; }
public DateTime CreationDate { get; set; } public DateTime CreationDate { get; set; }
} }

View File

@ -4,11 +4,9 @@ using Bit.Api.Auth.Models.Request;
using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Auth.Models.Request.Accounts;
using Bit.Api.Auth.Models.Request.WebAuthn; using Bit.Api.Auth.Models.Request.WebAuthn;
using Bit.Api.KeyManagement.Validators; using Bit.Api.KeyManagement.Validators;
using Bit.Api.Models.Request;
using Bit.Api.Models.Request.Accounts; using Bit.Api.Models.Request.Accounts;
using Bit.Api.Models.Response; using Bit.Api.Models.Response;
using Bit.Api.Tools.Models.Request; using Bit.Api.Tools.Models.Request;
using Bit.Api.Utilities;
using Bit.Api.Vault.Models.Request; using Bit.Api.Vault.Models.Request;
using Bit.Core; using Bit.Core;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
@ -19,23 +17,15 @@ using Bit.Core.Auth.Models.Api.Request.Accounts;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.UserFeatures.TdeOffboardingPassword.Interfaces; using Bit.Core.Auth.UserFeatures.TdeOffboardingPassword.Interfaces;
using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.KeyManagement.Models.Data; using Bit.Core.KeyManagement.Models.Data;
using Bit.Core.KeyManagement.UserKey; using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Models.Api.Response; using Bit.Core.Models.Api.Response;
using Bit.Core.Models.Business;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities; using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
@ -47,20 +37,15 @@ namespace Bit.Api.Auth.Controllers;
[Authorize("Application")] [Authorize("Application")]
public class AccountsController : Controller public class AccountsController : Controller
{ {
private readonly GlobalSettings _globalSettings;
private readonly IOrganizationService _organizationService; private readonly IOrganizationService _organizationService;
private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IProviderUserRepository _providerUserRepository; private readonly IProviderUserRepository _providerUserRepository;
private readonly IPaymentService _paymentService;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly IPolicyService _policyService; private readonly IPolicyService _policyService;
private readonly ISetInitialMasterPasswordCommand _setInitialMasterPasswordCommand; private readonly ISetInitialMasterPasswordCommand _setInitialMasterPasswordCommand;
private readonly ITdeOffboardingPasswordCommand _tdeOffboardingPasswordCommand; private readonly ITdeOffboardingPasswordCommand _tdeOffboardingPasswordCommand;
private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IRotateUserKeyCommand _rotateUserKeyCommand;
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
private readonly ISubscriberService _subscriberService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
private readonly IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>> _cipherValidator; private readonly IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>> _cipherValidator;
private readonly IRotationValidator<IEnumerable<FolderWithIdRequestModel>, IEnumerable<Folder>> _folderValidator; private readonly IRotationValidator<IEnumerable<FolderWithIdRequestModel>, IEnumerable<Folder>> _folderValidator;
@ -75,20 +60,15 @@ public class AccountsController : Controller
public AccountsController( public AccountsController(
GlobalSettings globalSettings,
IOrganizationService organizationService, IOrganizationService organizationService,
IOrganizationUserRepository organizationUserRepository, IOrganizationUserRepository organizationUserRepository,
IProviderUserRepository providerUserRepository, IProviderUserRepository providerUserRepository,
IPaymentService paymentService,
IUserService userService, IUserService userService,
IPolicyService policyService, IPolicyService policyService,
ISetInitialMasterPasswordCommand setInitialMasterPasswordCommand, ISetInitialMasterPasswordCommand setInitialMasterPasswordCommand,
ITdeOffboardingPasswordCommand tdeOffboardingPasswordCommand, ITdeOffboardingPasswordCommand tdeOffboardingPasswordCommand,
IRotateUserKeyCommand rotateUserKeyCommand, IRotateUserKeyCommand rotateUserKeyCommand,
IFeatureService featureService, IFeatureService featureService,
ISubscriberService subscriberService,
IReferenceEventService referenceEventService,
ICurrentContext currentContext,
IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>> cipherValidator, IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>> cipherValidator,
IRotationValidator<IEnumerable<FolderWithIdRequestModel>, IEnumerable<Folder>> folderValidator, IRotationValidator<IEnumerable<FolderWithIdRequestModel>, IEnumerable<Folder>> folderValidator,
IRotationValidator<IEnumerable<SendWithIdRequestModel>, IReadOnlyList<Send>> sendValidator, IRotationValidator<IEnumerable<SendWithIdRequestModel>, IReadOnlyList<Send>> sendValidator,
@ -99,20 +79,15 @@ public class AccountsController : Controller
IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>> webAuthnKeyValidator IRotationValidator<IEnumerable<WebAuthnLoginRotateKeyRequestModel>, IEnumerable<WebAuthnLoginRotateKeyData>> webAuthnKeyValidator
) )
{ {
_globalSettings = globalSettings;
_organizationService = organizationService; _organizationService = organizationService;
_organizationUserRepository = organizationUserRepository; _organizationUserRepository = organizationUserRepository;
_providerUserRepository = providerUserRepository; _providerUserRepository = providerUserRepository;
_paymentService = paymentService;
_userService = userService; _userService = userService;
_policyService = policyService; _policyService = policyService;
_setInitialMasterPasswordCommand = setInitialMasterPasswordCommand; _setInitialMasterPasswordCommand = setInitialMasterPasswordCommand;
_tdeOffboardingPasswordCommand = tdeOffboardingPasswordCommand; _tdeOffboardingPasswordCommand = tdeOffboardingPasswordCommand;
_rotateUserKeyCommand = rotateUserKeyCommand; _rotateUserKeyCommand = rotateUserKeyCommand;
_featureService = featureService; _featureService = featureService;
_subscriberService = subscriberService;
_referenceEventService = referenceEventService;
_currentContext = currentContext;
_cipherValidator = cipherValidator; _cipherValidator = cipherValidator;
_folderValidator = folderValidator; _folderValidator = folderValidator;
_sendValidator = sendValidator; _sendValidator = sendValidator;
@ -638,212 +613,6 @@ public class AccountsController : Controller
throw new BadRequestException(ModelState); throw new BadRequestException(ModelState);
} }
[HttpPost("premium")]
public async Task<PaymentResponseModel> PostPremium(PremiumRequestModel model)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var valid = model.Validate(_globalSettings);
UserLicense license = null;
if (valid && _globalSettings.SelfHosted)
{
license = await ApiHelpers.ReadJsonFileFromBody<UserLicense>(HttpContext, model.License);
}
if (!valid && !_globalSettings.SelfHosted && string.IsNullOrWhiteSpace(model.Country))
{
throw new BadRequestException("Country is required.");
}
if (!valid || (_globalSettings.SelfHosted && license == null))
{
throw new BadRequestException("Invalid license.");
}
var result = await _userService.SignUpPremiumAsync(user, model.PaymentToken,
model.PaymentMethodType.Value, model.AdditionalStorageGb.GetValueOrDefault(0), license,
new TaxInfo
{
BillingAddressCountry = model.Country,
BillingAddressPostalCode = model.PostalCode
});
var userTwoFactorEnabled = await _userService.TwoFactorIsEnabledAsync(user);
var userHasPremiumFromOrganization = await _userService.HasPremiumFromOrganization(user);
var organizationIdsManagingActiveUser = await GetOrganizationIdsManagingUserAsync(user.Id);
var profile = new ProfileResponseModel(user, null, null, null, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsManagingActiveUser);
return new PaymentResponseModel
{
UserProfile = profile,
PaymentIntentClientSecret = result.Item2,
Success = result.Item1
};
}
[HttpGet("subscription")]
public async Task<SubscriptionResponseModel> GetSubscription()
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
if (!_globalSettings.SelfHosted && user.Gateway != null)
{
var subscriptionInfo = await _paymentService.GetSubscriptionAsync(user);
var license = await _userService.GenerateLicenseAsync(user, subscriptionInfo);
return new SubscriptionResponseModel(user, subscriptionInfo, license);
}
else if (!_globalSettings.SelfHosted)
{
var license = await _userService.GenerateLicenseAsync(user);
return new SubscriptionResponseModel(user, license);
}
else
{
return new SubscriptionResponseModel(user);
}
}
[HttpPost("payment")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task PostPayment([FromBody] PaymentRequestModel model)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
await _userService.ReplacePaymentMethodAsync(user, model.PaymentToken, model.PaymentMethodType.Value,
new TaxInfo
{
BillingAddressLine1 = model.Line1,
BillingAddressLine2 = model.Line2,
BillingAddressCity = model.City,
BillingAddressState = model.State,
BillingAddressCountry = model.Country,
BillingAddressPostalCode = model.PostalCode,
TaxIdNumber = model.TaxId
});
}
[HttpPost("storage")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<PaymentResponseModel> PostStorage([FromBody] StorageRequestModel model)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var result = await _userService.AdjustStorageAsync(user, model.StorageGbAdjustment.Value);
return new PaymentResponseModel
{
Success = true,
PaymentIntentClientSecret = result
};
}
[HttpPost("license")]
[SelfHosted(SelfHostedOnly = true)]
public async Task PostLicense(LicenseRequestModel model)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var license = await ApiHelpers.ReadJsonFileFromBody<UserLicense>(HttpContext, model.License);
if (license == null)
{
throw new BadRequestException("Invalid license");
}
await _userService.UpdateLicenseAsync(user, license);
}
[HttpPost("cancel")]
public async Task PostCancel([FromBody] SubscriptionCancellationRequestModel request)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
await _subscriberService.CancelSubscription(user,
new OffboardingSurveyResponse
{
UserId = user.Id,
Reason = request.Reason,
Feedback = request.Feedback
},
user.IsExpired());
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(
ReferenceEventType.CancelSubscription,
user,
_currentContext)
{
EndOfPeriod = user.IsExpired()
});
}
[HttpPost("reinstate-premium")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task PostReinstate()
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
await _userService.ReinstatePremiumAsync(user);
}
[HttpGet("tax")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<TaxInfoResponseModel> GetTaxInfo()
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var taxInfo = await _paymentService.GetTaxInfoAsync(user);
return new TaxInfoResponseModel(taxInfo);
}
[HttpPut("tax")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task PutTaxInfo([FromBody] TaxInfoUpdateRequestModel model)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var taxInfo = new TaxInfo
{
BillingAddressPostalCode = model.PostalCode,
BillingAddressCountry = model.Country,
};
await _paymentService.SaveTaxInfoAsync(user, taxInfo);
}
[HttpDelete("sso/{organizationId}")] [HttpDelete("sso/{organizationId}")]
public async Task DeleteSsoUser(string organizationId) public async Task DeleteSsoUser(string organizationId)
{ {

View File

@ -167,7 +167,7 @@ public class EmergencyAccessController : Controller
{ {
var user = await _userService.GetUserByPrincipalAsync(User); var user = await _userService.GetUserByPrincipalAsync(User);
var viewResult = await _emergencyAccessService.ViewAsync(id, user); var viewResult = await _emergencyAccessService.ViewAsync(id, user);
return new EmergencyAccessViewResponseModel(_globalSettings, viewResult.EmergencyAccess, viewResult.Ciphers); return new EmergencyAccessViewResponseModel(_globalSettings, viewResult.EmergencyAccess, viewResult.Ciphers, user);
} }
[HttpGet("{id}/{cipherId}/attachment/{attachmentId}")] [HttpGet("{id}/{cipherId}/attachment/{attachmentId}")]

View File

@ -23,6 +23,7 @@ public class AuthRequestResponseModel : ResponseModel
RequestDeviceType = authRequest.RequestDeviceType.GetType().GetMember(authRequest.RequestDeviceType.ToString()) RequestDeviceType = authRequest.RequestDeviceType.GetType().GetMember(authRequest.RequestDeviceType.ToString())
.FirstOrDefault()?.GetCustomAttribute<DisplayAttribute>()?.GetName(); .FirstOrDefault()?.GetCustomAttribute<DisplayAttribute>()?.GetName();
RequestIpAddress = authRequest.RequestIpAddress; RequestIpAddress = authRequest.RequestIpAddress;
RequestCountryName = authRequest.RequestCountryName;
Key = authRequest.Key; Key = authRequest.Key;
MasterPasswordHash = authRequest.MasterPasswordHash; MasterPasswordHash = authRequest.MasterPasswordHash;
CreationDate = authRequest.CreationDate; CreationDate = authRequest.CreationDate;
@ -37,6 +38,7 @@ public class AuthRequestResponseModel : ResponseModel
public DeviceType RequestDeviceTypeValue { get; set; } public DeviceType RequestDeviceTypeValue { get; set; }
public string RequestDeviceType { get; set; } public string RequestDeviceType { get; set; }
public string RequestIpAddress { get; set; } public string RequestIpAddress { get; set; }
public string RequestCountryName { get; set; }
public string Key { get; set; } public string Key { get; set; }
public string MasterPasswordHash { get; set; } public string MasterPasswordHash { get; set; }
public DateTime CreationDate { get; set; } public DateTime CreationDate { get; set; }

View File

@ -116,11 +116,17 @@ public class EmergencyAccessViewResponseModel : ResponseModel
public EmergencyAccessViewResponseModel( public EmergencyAccessViewResponseModel(
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
EmergencyAccess emergencyAccess, EmergencyAccess emergencyAccess,
IEnumerable<CipherDetails> ciphers) IEnumerable<CipherDetails> ciphers,
User user)
: base("emergencyAccessView") : base("emergencyAccessView")
{ {
KeyEncrypted = emergencyAccess.KeyEncrypted; KeyEncrypted = emergencyAccess.KeyEncrypted;
Ciphers = ciphers.Select(c => new CipherResponseModel(c, globalSettings)); Ciphers = ciphers.Select(cipher =>
new CipherResponseModel(
cipher,
user,
organizationAbilities: null, // Emergency access only retrieves personal ciphers so organizationAbilities is not needed
globalSettings));
} }
public string KeyEncrypted { get; set; } public string KeyEncrypted { get; set; }

View File

@ -0,0 +1,237 @@
#nullable enable
using Bit.Api.Models.Request;
using Bit.Api.Models.Request.Accounts;
using Bit.Api.Models.Response;
using Bit.Api.Utilities;
using Bit.Core.Billing.Models;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Enums;
using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Services;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
namespace Bit.Api.Billing.Controllers;
[Route("accounts")]
[Authorize("Application")]
public class AccountsController(
IUserService userService) : Controller
{
[HttpPost("premium")]
public async Task<PaymentResponseModel> PostPremiumAsync(
PremiumRequestModel model,
[FromServices] GlobalSettings globalSettings)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var valid = model.Validate(globalSettings);
UserLicense? license = null;
if (valid && globalSettings.SelfHosted)
{
license = await ApiHelpers.ReadJsonFileFromBody<UserLicense>(HttpContext, model.License);
}
if (!valid && !globalSettings.SelfHosted && string.IsNullOrWhiteSpace(model.Country))
{
throw new BadRequestException("Country is required.");
}
if (!valid || (globalSettings.SelfHosted && license == null))
{
throw new BadRequestException("Invalid license.");
}
var result = await userService.SignUpPremiumAsync(user, model.PaymentToken,
model.PaymentMethodType!.Value, model.AdditionalStorageGb.GetValueOrDefault(0), license,
new TaxInfo { BillingAddressCountry = model.Country, BillingAddressPostalCode = model.PostalCode });
var userTwoFactorEnabled = await userService.TwoFactorIsEnabledAsync(user);
var userHasPremiumFromOrganization = await userService.HasPremiumFromOrganization(user);
var organizationIdsManagingActiveUser = await GetOrganizationIdsManagingUserAsync(user.Id);
var profile = new ProfileResponseModel(user, null, null, null, userTwoFactorEnabled,
userHasPremiumFromOrganization, organizationIdsManagingActiveUser);
return new PaymentResponseModel
{
UserProfile = profile,
PaymentIntentClientSecret = result.Item2,
Success = result.Item1
};
}
[HttpGet("subscription")]
public async Task<SubscriptionResponseModel> GetSubscriptionAsync(
[FromServices] GlobalSettings globalSettings,
[FromServices] IPaymentService paymentService)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
if (!globalSettings.SelfHosted && user.Gateway != null)
{
var subscriptionInfo = await paymentService.GetSubscriptionAsync(user);
var license = await userService.GenerateLicenseAsync(user, subscriptionInfo);
return new SubscriptionResponseModel(user, subscriptionInfo, license);
}
else if (!globalSettings.SelfHosted)
{
var license = await userService.GenerateLicenseAsync(user);
return new SubscriptionResponseModel(user, license);
}
else
{
return new SubscriptionResponseModel(user);
}
}
[HttpPost("payment")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task PostPaymentAsync([FromBody] PaymentRequestModel model)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
await userService.ReplacePaymentMethodAsync(user, model.PaymentToken, model.PaymentMethodType!.Value,
new TaxInfo
{
BillingAddressLine1 = model.Line1,
BillingAddressLine2 = model.Line2,
BillingAddressCity = model.City,
BillingAddressState = model.State,
BillingAddressCountry = model.Country,
BillingAddressPostalCode = model.PostalCode,
TaxIdNumber = model.TaxId
});
}
[HttpPost("storage")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<PaymentResponseModel> PostStorageAsync([FromBody] StorageRequestModel model)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var result = await userService.AdjustStorageAsync(user, model.StorageGbAdjustment!.Value);
return new PaymentResponseModel { Success = true, PaymentIntentClientSecret = result };
}
[HttpPost("license")]
[SelfHosted(SelfHostedOnly = true)]
public async Task PostLicenseAsync(LicenseRequestModel model)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var license = await ApiHelpers.ReadJsonFileFromBody<UserLicense>(HttpContext, model.License);
if (license == null)
{
throw new BadRequestException("Invalid license");
}
await userService.UpdateLicenseAsync(user, license);
}
[HttpPost("cancel")]
public async Task PostCancelAsync(
[FromBody] SubscriptionCancellationRequestModel request,
[FromServices] ICurrentContext currentContext,
[FromServices] IReferenceEventService referenceEventService,
[FromServices] ISubscriberService subscriberService)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
await subscriberService.CancelSubscription(user,
new OffboardingSurveyResponse { UserId = user.Id, Reason = request.Reason, Feedback = request.Feedback },
user.IsExpired());
await referenceEventService.RaiseEventAsync(new ReferenceEvent(
ReferenceEventType.CancelSubscription,
user,
currentContext)
{ EndOfPeriod = user.IsExpired() });
}
[HttpPost("reinstate-premium")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task PostReinstateAsync()
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
await userService.ReinstatePremiumAsync(user);
}
[HttpGet("tax")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<TaxInfoResponseModel> GetTaxInfoAsync(
[FromServices] IPaymentService paymentService)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var taxInfo = await paymentService.GetTaxInfoAsync(user);
return new TaxInfoResponseModel(taxInfo);
}
[HttpPut("tax")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task PutTaxInfoAsync(
[FromBody] TaxInfoUpdateRequestModel model,
[FromServices] IPaymentService paymentService)
{
var user = await userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
var taxInfo = new TaxInfo
{
BillingAddressPostalCode = model.PostalCode,
BillingAddressCountry = model.Country,
};
await paymentService.SaveTaxInfoAsync(user, taxInfo);
}
private async Task<IEnumerable<Guid>> GetOrganizationIdsManagingUserAsync(Guid userId)
{
var organizationManagingUser = await userService.GetOrganizationsManagingUserAsync(userId);
return organizationManagingUser.Select(o => o.Id);
}
}

View File

@ -1,5 +1,6 @@
using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses; using Bit.Api.Billing.Models.Responses;
using Bit.Core;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
@ -20,6 +21,7 @@ namespace Bit.Api.Billing.Controllers;
[Authorize("Application")] [Authorize("Application")]
public class ProviderBillingController( public class ProviderBillingController(
ICurrentContext currentContext, ICurrentContext currentContext,
IFeatureService featureService,
ILogger<BaseProviderController> logger, ILogger<BaseProviderController> logger,
IPricingClient pricingClient, IPricingClient pricingClient,
IProviderBillingService providerBillingService, IProviderBillingService providerBillingService,
@ -71,6 +73,65 @@ public class ProviderBillingController(
"text/csv"); "text/csv");
} }
[HttpPut("payment-method")]
public async Task<IResult> UpdatePaymentMethodAsync(
[FromRoute] Guid providerId,
[FromBody] UpdatePaymentMethodRequestBody requestBody)
{
var allowProviderPaymentMethod = featureService.IsEnabled(FeatureFlagKeys.PM18794_ProviderPaymentMethod);
if (!allowProviderPaymentMethod)
{
return TypedResults.NotFound();
}
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var tokenizedPaymentSource = requestBody.PaymentSource.ToDomain();
var taxInformation = requestBody.TaxInformation.ToDomain();
await providerBillingService.UpdatePaymentMethod(
provider,
tokenizedPaymentSource,
taxInformation);
return TypedResults.Ok();
}
[HttpPost("payment-method/verify-bank-account")]
public async Task<IResult> VerifyBankAccountAsync(
[FromRoute] Guid providerId,
[FromBody] VerifyBankAccountRequestBody requestBody)
{
var allowProviderPaymentMethod = featureService.IsEnabled(FeatureFlagKeys.PM18794_ProviderPaymentMethod);
if (!allowProviderPaymentMethod)
{
return TypedResults.NotFound();
}
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
if (requestBody.DescriptorCode.Length != 6 || !requestBody.DescriptorCode.StartsWith("SM"))
{
return Error.BadRequest("Statement descriptor should be a 6-character value that starts with 'SM'");
}
await subscriberService.VerifyBankAccount(provider, requestBody.DescriptorCode);
return TypedResults.Ok();
}
[HttpGet("subscription")] [HttpGet("subscription")]
public async Task<IResult> GetSubscriptionAsync([FromRoute] Guid providerId) public async Task<IResult> GetSubscriptionAsync([FromRoute] Guid providerId)
{ {
@ -102,12 +163,32 @@ public class ProviderBillingController(
var subscriptionSuspension = await GetSubscriptionSuspensionAsync(stripeAdapter, subscription); var subscriptionSuspension = await GetSubscriptionSuspensionAsync(stripeAdapter, subscription);
var paymentSource = await subscriberService.GetPaymentSource(provider);
var response = ProviderSubscriptionResponse.From( var response = ProviderSubscriptionResponse.From(
subscription, subscription,
configuredProviderPlans, configuredProviderPlans,
taxInformation, taxInformation,
subscriptionSuspension, subscriptionSuspension,
provider); provider,
paymentSource);
return TypedResults.Ok(response);
}
[HttpGet("tax-information")]
public async Task<IResult> GetTaxInformationAsync([FromRoute] Guid providerId)
{
var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId);
if (provider == null)
{
return result;
}
var taxInformation = await subscriberService.GetTaxInformation(provider);
var response = TaxInformationResponse.From(taxInformation);
return TypedResults.Ok(response); return TypedResults.Ok(response);
} }

View File

@ -16,7 +16,8 @@ public record ProviderSubscriptionResponse(
TaxInformation TaxInformation, TaxInformation TaxInformation,
DateTime? CancelAt, DateTime? CancelAt,
SubscriptionSuspension Suspension, SubscriptionSuspension Suspension,
ProviderType ProviderType) ProviderType ProviderType,
PaymentSource PaymentSource)
{ {
private const string _annualCadence = "Annual"; private const string _annualCadence = "Annual";
private const string _monthlyCadence = "Monthly"; private const string _monthlyCadence = "Monthly";
@ -26,7 +27,8 @@ public record ProviderSubscriptionResponse(
ICollection<ConfiguredProviderPlan> providerPlans, ICollection<ConfiguredProviderPlan> providerPlans,
TaxInformation taxInformation, TaxInformation taxInformation,
SubscriptionSuspension subscriptionSuspension, SubscriptionSuspension subscriptionSuspension,
Provider provider) Provider provider,
PaymentSource paymentSource)
{ {
var providerPlanResponses = providerPlans var providerPlanResponses = providerPlans
.Select(providerPlan => .Select(providerPlan =>
@ -57,7 +59,8 @@ public record ProviderSubscriptionResponse(
taxInformation, taxInformation,
subscription.CancelAt, subscription.CancelAt,
subscriptionSuspension, subscriptionSuspension,
provider.Type); provider.Type,
paymentSource);
} }
} }

View File

@ -43,8 +43,9 @@ public class PushController : Controller
public async Task RegisterAsync([FromBody] PushRegistrationRequestModel model) public async Task RegisterAsync([FromBody] PushRegistrationRequestModel model)
{ {
CheckUsage(); CheckUsage();
await _pushRegistrationService.CreateOrUpdateRegistrationAsync(new PushRegistrationData(model.PushToken), Prefix(model.DeviceId), await _pushRegistrationService.CreateOrUpdateRegistrationAsync(new PushRegistrationData(model.PushToken),
Prefix(model.UserId), Prefix(model.Identifier), model.Type, model.OrganizationIds.Select(Prefix), model.InstallationId); Prefix(model.DeviceId), Prefix(model.UserId), Prefix(model.Identifier), model.Type,
model.OrganizationIds?.Select(Prefix) ?? [], model.InstallationId);
} }
[HttpPost("delete")] [HttpPost("delete")]

View File

@ -56,7 +56,7 @@ public class ImportCiphersController : Controller
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var folders = model.Folders.Select(f => f.ToFolder(userId)).ToList(); var folders = model.Folders.Select(f => f.ToFolder(userId)).ToList();
var ciphers = model.Ciphers.Select(c => c.ToCipherDetails(userId, false)).ToList(); var ciphers = model.Ciphers.Select(c => c.ToCipherDetails(userId, false)).ToList();
await _importCiphersCommand.ImportIntoIndividualVaultAsync(folders, ciphers, model.FolderRelationships); await _importCiphersCommand.ImportIntoIndividualVaultAsync(folders, ciphers, model.FolderRelationships, userId);
} }
[HttpPost("import-organization")] [HttpPost("import-organization")]

View File

@ -12,7 +12,7 @@ public static class CommandResultExtensions
NoRecordFoundFailure<T> failure => new ObjectResult(failure.ErrorMessages) { StatusCode = StatusCodes.Status404NotFound }, NoRecordFoundFailure<T> failure => new ObjectResult(failure.ErrorMessages) { StatusCode = StatusCodes.Status404NotFound },
BadRequestFailure<T> failure => new ObjectResult(failure.ErrorMessages) { StatusCode = StatusCodes.Status400BadRequest }, BadRequestFailure<T> failure => new ObjectResult(failure.ErrorMessages) { StatusCode = StatusCodes.Status400BadRequest },
Failure<T> failure => new ObjectResult(failure.ErrorMessages) { StatusCode = StatusCodes.Status400BadRequest }, Failure<T> failure => new ObjectResult(failure.ErrorMessages) { StatusCode = StatusCodes.Status400BadRequest },
Success<T> success => new ObjectResult(success.Data) { StatusCode = StatusCodes.Status200OK }, Success<T> success => new ObjectResult(success.Value) { StatusCode = StatusCodes.Status200OK },
_ => throw new InvalidOperationException($"Unhandled commandResult type: {commandResult.GetType().Name}") _ => throw new InvalidOperationException($"Unhandled commandResult type: {commandResult.GetType().Name}")
}; };
} }

View File

@ -79,14 +79,16 @@ public class CiphersController : Controller
[HttpGet("{id}")] [HttpGet("{id}")]
public async Task<CipherResponseModel> Get(Guid id) public async Task<CipherResponseModel> Get(Guid id)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = await GetByIdAsync(id, userId); var cipher = await GetByIdAsync(id, user.Id);
if (cipher == null) if (cipher == null)
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
return new CipherResponseModel(cipher, _globalSettings); var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
return new CipherResponseModel(cipher, user, organizationAbilities, _globalSettings);
} }
[HttpGet("{id}/admin")] [HttpGet("{id}/admin")]
@ -109,32 +111,37 @@ public class CiphersController : Controller
[HttpGet("{id}/details")] [HttpGet("{id}/details")]
public async Task<CipherDetailsResponseModel> GetDetails(Guid id) public async Task<CipherDetailsResponseModel> GetDetails(Guid id)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = await GetByIdAsync(id, userId); var cipher = await GetByIdAsync(id, user.Id);
if (cipher == null) if (cipher == null)
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id); var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
return new CipherDetailsResponseModel(cipher, _globalSettings, collectionCiphers); var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(user.Id, id);
return new CipherDetailsResponseModel(cipher, user, organizationAbilities, _globalSettings, collectionCiphers);
} }
[HttpGet("")] [HttpGet("")]
public async Task<ListResponseModel<CipherDetailsResponseModel>> Get() public async Task<ListResponseModel<CipherDetailsResponseModel>> Get()
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var hasOrgs = _currentContext.Organizations?.Any() ?? false; var hasOrgs = _currentContext.Organizations.Count != 0;
// TODO: Use hasOrgs proper for cipher listing here? // TODO: Use hasOrgs proper for cipher listing here?
var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, withOrganizations: true || hasOrgs); var ciphers = await _cipherRepository.GetManyByUserIdAsync(user.Id, withOrganizations: true);
Dictionary<Guid, IGrouping<Guid, CollectionCipher>> collectionCiphersGroupDict = null; Dictionary<Guid, IGrouping<Guid, CollectionCipher>> collectionCiphersGroupDict = null;
if (hasOrgs) if (hasOrgs)
{ {
var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(userId); var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(user.Id);
collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key); collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key);
} }
var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var responses = ciphers.Select(c => new CipherDetailsResponseModel(c, _globalSettings, var responses = ciphers.Select(cipher => new CipherDetailsResponseModel(
cipher,
user,
organizationAbilities,
_globalSettings,
collectionCiphersGroupDict)).ToList(); collectionCiphersGroupDict)).ToList();
return new ListResponseModel<CipherDetailsResponseModel>(responses); return new ListResponseModel<CipherDetailsResponseModel>(responses);
} }
@ -142,30 +149,38 @@ public class CiphersController : Controller
[HttpPost("")] [HttpPost("")]
public async Task<CipherResponseModel> Post([FromBody] CipherRequestModel model) public async Task<CipherResponseModel> Post([FromBody] CipherRequestModel model)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = model.ToCipherDetails(userId); var cipher = model.ToCipherDetails(user.Id);
if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
await _cipherService.SaveDetailsAsync(cipher, userId, model.LastKnownRevisionDate, null, cipher.OrganizationId.HasValue); await _cipherService.SaveDetailsAsync(cipher, user.Id, model.LastKnownRevisionDate, null, cipher.OrganizationId.HasValue);
var response = new CipherResponseModel(cipher, _globalSettings); var response = new CipherResponseModel(
cipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings);
return response; return response;
} }
[HttpPost("create")] [HttpPost("create")]
public async Task<CipherResponseModel> PostCreate([FromBody] CipherCreateRequestModel model) public async Task<CipherResponseModel> PostCreate([FromBody] CipherCreateRequestModel model)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = model.Cipher.ToCipherDetails(userId); var cipher = model.Cipher.ToCipherDetails(user.Id);
if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value))
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
await _cipherService.SaveDetailsAsync(cipher, userId, model.Cipher.LastKnownRevisionDate, model.CollectionIds, cipher.OrganizationId.HasValue); await _cipherService.SaveDetailsAsync(cipher, user.Id, model.Cipher.LastKnownRevisionDate, model.CollectionIds, cipher.OrganizationId.HasValue);
var response = new CipherResponseModel(cipher, _globalSettings); var response = new CipherResponseModel(
cipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings);
return response; return response;
} }
@ -191,8 +206,8 @@ public class CiphersController : Controller
[HttpPost("{id}")] [HttpPost("{id}")]
public async Task<CipherResponseModel> Put(Guid id, [FromBody] CipherRequestModel model) public async Task<CipherResponseModel> Put(Guid id, [FromBody] CipherRequestModel model)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = await GetByIdAsync(id, userId); var cipher = await GetByIdAsync(id, user.Id);
if (cipher == null) if (cipher == null)
{ {
throw new NotFoundException(); throw new NotFoundException();
@ -200,7 +215,7 @@ public class CiphersController : Controller
ValidateClientVersionForFido2CredentialSupport(cipher); ValidateClientVersionForFido2CredentialSupport(cipher);
var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id)).Select(c => c.CollectionId).ToList(); var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(user.Id, id)).Select(c => c.CollectionId).ToList();
var modelOrgId = string.IsNullOrWhiteSpace(model.OrganizationId) ? var modelOrgId = string.IsNullOrWhiteSpace(model.OrganizationId) ?
(Guid?)null : new Guid(model.OrganizationId); (Guid?)null : new Guid(model.OrganizationId);
if (cipher.OrganizationId != modelOrgId) if (cipher.OrganizationId != modelOrgId)
@ -209,9 +224,13 @@ public class CiphersController : Controller
"then try again."); "then try again.");
} }
await _cipherService.SaveDetailsAsync(model.ToCipherDetails(cipher), userId, model.LastKnownRevisionDate, collectionIds); await _cipherService.SaveDetailsAsync(model.ToCipherDetails(cipher), user.Id, model.LastKnownRevisionDate, collectionIds);
var response = new CipherResponseModel(cipher, _globalSettings); var response = new CipherResponseModel(
cipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings);
return response; return response;
} }
@ -278,7 +297,14 @@ public class CiphersController : Controller
})); }));
} }
var responses = ciphers.Select(c => new CipherDetailsResponseModel(c, _globalSettings)); var user = await _userService.GetUserByPrincipalAsync(User);
var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var responses = ciphers.Select(cipher =>
new CipherDetailsResponseModel(
cipher,
user,
organizationAbilities,
_globalSettings));
return new ListResponseModel<CipherDetailsResponseModel>(responses); return new ListResponseModel<CipherDetailsResponseModel>(responses);
} }
@ -572,12 +598,16 @@ public class CiphersController : Controller
[HttpPost("{id}/partial")] [HttpPost("{id}/partial")]
public async Task<CipherResponseModel> PutPartial(Guid id, [FromBody] CipherPartialRequestModel model) public async Task<CipherResponseModel> PutPartial(Guid id, [FromBody] CipherPartialRequestModel model)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var folderId = string.IsNullOrWhiteSpace(model.FolderId) ? null : (Guid?)new Guid(model.FolderId); var folderId = string.IsNullOrWhiteSpace(model.FolderId) ? null : (Guid?)new Guid(model.FolderId);
await _cipherRepository.UpdatePartialAsync(id, userId, folderId, model.Favorite); await _cipherRepository.UpdatePartialAsync(id, user.Id, folderId, model.Favorite);
var cipher = await GetByIdAsync(id, userId); var cipher = await GetByIdAsync(id, user.Id);
var response = new CipherResponseModel(cipher, _globalSettings); var response = new CipherResponseModel(
cipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings);
return response; return response;
} }
@ -585,9 +615,9 @@ public class CiphersController : Controller
[HttpPost("{id}/share")] [HttpPost("{id}/share")]
public async Task<CipherResponseModel> PutShare(Guid id, [FromBody] CipherShareRequestModel model) public async Task<CipherResponseModel> PutShare(Guid id, [FromBody] CipherShareRequestModel model)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = await _cipherRepository.GetByIdAsync(id); var cipher = await _cipherRepository.GetByIdAsync(id);
if (cipher == null || cipher.UserId != userId || if (cipher == null || cipher.UserId != user.Id ||
!await _currentContext.OrganizationUser(new Guid(model.Cipher.OrganizationId))) !await _currentContext.OrganizationUser(new Guid(model.Cipher.OrganizationId)))
{ {
throw new NotFoundException(); throw new NotFoundException();
@ -597,10 +627,14 @@ public class CiphersController : Controller
var original = cipher.Clone(); var original = cipher.Clone();
await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher), new Guid(model.Cipher.OrganizationId), await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher), new Guid(model.Cipher.OrganizationId),
model.CollectionIds.Select(c => new Guid(c)), userId, model.Cipher.LastKnownRevisionDate); model.CollectionIds.Select(c => new Guid(c)), user.Id, model.Cipher.LastKnownRevisionDate);
var sharedCipher = await GetByIdAsync(id, userId); var sharedCipher = await GetByIdAsync(id, user.Id);
var response = new CipherResponseModel(sharedCipher, _globalSettings); var response = new CipherResponseModel(
sharedCipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings);
return response; return response;
} }
@ -608,8 +642,8 @@ public class CiphersController : Controller
[HttpPost("{id}/collections")] [HttpPost("{id}/collections")]
public async Task<CipherDetailsResponseModel> PutCollections(Guid id, [FromBody] CipherCollectionsRequestModel model) public async Task<CipherDetailsResponseModel> PutCollections(Guid id, [FromBody] CipherCollectionsRequestModel model)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = await GetByIdAsync(id, userId); var cipher = await GetByIdAsync(id, user.Id);
if (cipher == null || !cipher.OrganizationId.HasValue || if (cipher == null || !cipher.OrganizationId.HasValue ||
!await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) !await _currentContext.OrganizationUser(cipher.OrganizationId.Value))
{ {
@ -617,20 +651,25 @@ public class CiphersController : Controller
} }
await _cipherService.SaveCollectionsAsync(cipher, await _cipherService.SaveCollectionsAsync(cipher,
model.CollectionIds.Select(c => new Guid(c)), userId, false); model.CollectionIds.Select(c => new Guid(c)), user.Id, false);
var updatedCipher = await GetByIdAsync(id, userId); var updatedCipher = await GetByIdAsync(id, user.Id);
var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id); var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(user.Id, id);
return new CipherDetailsResponseModel(updatedCipher, _globalSettings, collectionCiphers); return new CipherDetailsResponseModel(
updatedCipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings,
collectionCiphers);
} }
[HttpPut("{id}/collections_v2")] [HttpPut("{id}/collections_v2")]
[HttpPost("{id}/collections_v2")] [HttpPost("{id}/collections_v2")]
public async Task<OptionalCipherDetailsResponseModel> PutCollections_vNext(Guid id, [FromBody] CipherCollectionsRequestModel model) public async Task<OptionalCipherDetailsResponseModel> PutCollections_vNext(Guid id, [FromBody] CipherCollectionsRequestModel model)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = await GetByIdAsync(id, userId); var cipher = await GetByIdAsync(id, user.Id);
if (cipher == null || !cipher.OrganizationId.HasValue || if (cipher == null || !cipher.OrganizationId.HasValue ||
!await _currentContext.OrganizationUser(cipher.OrganizationId.Value) || !cipher.ViewPassword) !await _currentContext.OrganizationUser(cipher.OrganizationId.Value) || !cipher.ViewPassword)
{ {
@ -638,10 +677,10 @@ public class CiphersController : Controller
} }
await _cipherService.SaveCollectionsAsync(cipher, await _cipherService.SaveCollectionsAsync(cipher,
model.CollectionIds.Select(c => new Guid(c)), userId, false); model.CollectionIds.Select(c => new Guid(c)), user.Id, false);
var updatedCipher = await GetByIdAsync(id, userId); var updatedCipher = await GetByIdAsync(id, user.Id);
var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id); var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(user.Id, id);
// If a user removes the last Can Manage access of a cipher, the "updatedCipher" will return null // If a user removes the last Can Manage access of a cipher, the "updatedCipher" will return null
// We will be returning an "Unavailable" property so the client knows the user can no longer access this // We will be returning an "Unavailable" property so the client knows the user can no longer access this
var response = new OptionalCipherDetailsResponseModel() var response = new OptionalCipherDetailsResponseModel()
@ -649,7 +688,12 @@ public class CiphersController : Controller
Unavailable = updatedCipher is null, Unavailable = updatedCipher is null,
Cipher = updatedCipher is null Cipher = updatedCipher is null
? null ? null
: new CipherDetailsResponseModel(updatedCipher, _globalSettings, collectionCiphers) : new CipherDetailsResponseModel(
updatedCipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings,
collectionCiphers)
}; };
return response; return response;
} }
@ -839,15 +883,19 @@ public class CiphersController : Controller
[HttpPut("{id}/restore")] [HttpPut("{id}/restore")]
public async Task<CipherResponseModel> PutRestore(Guid id) public async Task<CipherResponseModel> PutRestore(Guid id)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = await GetByIdAsync(id, userId); var cipher = await GetByIdAsync(id, user.Id);
if (cipher == null) if (cipher == null)
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
await _cipherService.RestoreAsync(cipher, userId); await _cipherService.RestoreAsync(cipher, user.Id);
return new CipherResponseModel(cipher, _globalSettings); return new CipherResponseModel(
cipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings);
} }
[HttpPut("{id}/restore-admin")] [HttpPut("{id}/restore-admin")]
@ -996,10 +1044,10 @@ public class CiphersController : Controller
[HttpPost("{id}/attachment/v2")] [HttpPost("{id}/attachment/v2")]
public async Task<AttachmentUploadDataResponseModel> PostAttachment(Guid id, [FromBody] AttachmentRequestModel request) public async Task<AttachmentUploadDataResponseModel> PostAttachment(Guid id, [FromBody] AttachmentRequestModel request)
{ {
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = request.AdminRequest ? var cipher = request.AdminRequest ?
await _cipherRepository.GetOrganizationDetailsByIdAsync(id) : await _cipherRepository.GetOrganizationDetailsByIdAsync(id) :
await GetByIdAsync(id, userId); await GetByIdAsync(id, user.Id);
if (cipher == null || (request.AdminRequest && (!cipher.OrganizationId.HasValue || if (cipher == null || (request.AdminRequest && (!cipher.OrganizationId.HasValue ||
!await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })))) !await CanEditCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id }))))
@ -1013,13 +1061,17 @@ public class CiphersController : Controller
} }
var (attachmentId, uploadUrl) = await _cipherService.CreateAttachmentForDelayedUploadAsync(cipher, var (attachmentId, uploadUrl) = await _cipherService.CreateAttachmentForDelayedUploadAsync(cipher,
request.Key, request.FileName, request.FileSize, request.AdminRequest, userId); request.Key, request.FileName, request.FileSize, request.AdminRequest, user.Id);
return new AttachmentUploadDataResponseModel return new AttachmentUploadDataResponseModel
{ {
AttachmentId = attachmentId, AttachmentId = attachmentId,
Url = uploadUrl, Url = uploadUrl,
FileUploadType = _attachmentStorageService.FileUploadType, FileUploadType = _attachmentStorageService.FileUploadType,
CipherResponse = request.AdminRequest ? null : new CipherResponseModel((CipherDetails)cipher, _globalSettings), CipherResponse = request.AdminRequest ? null : new CipherResponseModel(
(CipherDetails)cipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings),
CipherMiniResponse = request.AdminRequest ? new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp) : null, CipherMiniResponse = request.AdminRequest ? new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp) : null,
}; };
} }
@ -1077,8 +1129,8 @@ public class CiphersController : Controller
{ {
ValidateAttachment(); ValidateAttachment();
var userId = _userService.GetProperUserId(User).Value; var user = await _userService.GetUserByPrincipalAsync(User);
var cipher = await GetByIdAsync(id, userId); var cipher = await GetByIdAsync(id, user.Id);
if (cipher == null) if (cipher == null)
{ {
throw new NotFoundException(); throw new NotFoundException();
@ -1087,10 +1139,14 @@ public class CiphersController : Controller
await Request.GetFileAsync(async (stream, fileName, key) => await Request.GetFileAsync(async (stream, fileName, key) =>
{ {
await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key,
Request.ContentLength.GetValueOrDefault(0), userId); Request.ContentLength.GetValueOrDefault(0), user.Id);
}); });
return new CipherResponseModel(cipher, _globalSettings); return new CipherResponseModel(
cipher,
user,
await _applicationCacheService.GetOrganizationAbilitiesAsync(),
_globalSettings);
} }
[HttpPost("{id}/attachment-admin")] [HttpPost("{id}/attachment-admin")]

View File

@ -36,6 +36,7 @@ public class SyncController : Controller
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly Version _sshKeyCipherMinimumVersion = new(Constants.SSHKeyCipherMinimumVersion); private readonly Version _sshKeyCipherMinimumVersion = new(Constants.SSHKeyCipherMinimumVersion);
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
private readonly IApplicationCacheService _applicationCacheService;
public SyncController( public SyncController(
IUserService userService, IUserService userService,
@ -49,7 +50,8 @@ public class SyncController : Controller
ISendRepository sendRepository, ISendRepository sendRepository,
GlobalSettings globalSettings, GlobalSettings globalSettings,
ICurrentContext currentContext, ICurrentContext currentContext,
IFeatureService featureService) IFeatureService featureService,
IApplicationCacheService applicationCacheService)
{ {
_userService = userService; _userService = userService;
_folderRepository = folderRepository; _folderRepository = folderRepository;
@ -63,6 +65,7 @@ public class SyncController : Controller
_globalSettings = globalSettings; _globalSettings = globalSettings;
_currentContext = currentContext; _currentContext = currentContext;
_featureService = featureService; _featureService = featureService;
_applicationCacheService = applicationCacheService;
} }
[HttpGet("")] [HttpGet("")]
@ -104,7 +107,9 @@ public class SyncController : Controller
var organizationManagingActiveUser = await _userService.GetOrganizationsManagingUserAsync(user.Id); var organizationManagingActiveUser = await _userService.GetOrganizationsManagingUserAsync(user.Id);
var organizationIdsManagingActiveUser = organizationManagingActiveUser.Select(o => o.Id); var organizationIdsManagingActiveUser = organizationManagingActiveUser.Select(o => o.Id);
var response = new SyncResponseModel(_globalSettings, user, userTwoFactorEnabled, userHasPremiumFromOrganization, var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var response = new SyncResponseModel(_globalSettings, user, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationAbilities,
organizationIdsManagingActiveUser, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails, organizationIdsManagingActiveUser, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails,
folders, collections, ciphers, collectionCiphersGroupDict, excludeDomains, policies, sends); folders, collections, ciphers, collectionCiphersGroupDict, excludeDomains, policies, sends);
return response; return response;

View File

@ -0,0 +1,27 @@
using Bit.Core.Entities;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Vault.Authorization.Permissions;
using Bit.Core.Vault.Models.Data;
namespace Bit.Api.Vault.Models.Response;
public record CipherPermissionsResponseModel
{
public bool Delete { get; init; }
public bool Restore { get; init; }
public CipherPermissionsResponseModel(
User user,
CipherDetails cipherDetails,
IDictionary<Guid, OrganizationAbility> organizationAbilities)
{
OrganizationAbility organizationAbility = null;
if (cipherDetails.OrganizationId.HasValue && !organizationAbilities.TryGetValue(cipherDetails.OrganizationId.Value, out organizationAbility))
{
throw new Exception("OrganizationAbility not found for organization cipher.");
}
Delete = NormalCipherPermissions.CanDelete(user, cipherDetails, organizationAbility);
Restore = NormalCipherPermissions.CanRestore(user, cipherDetails, organizationAbility);
}
}

View File

@ -1,6 +1,7 @@
using System.Text.Json; using System.Text.Json;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Models.Api; using Bit.Core.Models.Api;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Enums; using Bit.Core.Vault.Enums;
@ -96,26 +97,37 @@ public class CipherMiniResponseModel : ResponseModel
public class CipherResponseModel : CipherMiniResponseModel public class CipherResponseModel : CipherMiniResponseModel
{ {
public CipherResponseModel(CipherDetails cipher, IGlobalSettings globalSettings, string obj = "cipher") public CipherResponseModel(
CipherDetails cipher,
User user,
IDictionary<Guid, OrganizationAbility> organizationAbilities,
IGlobalSettings globalSettings,
string obj = "cipher")
: base(cipher, globalSettings, cipher.OrganizationUseTotp, obj) : base(cipher, globalSettings, cipher.OrganizationUseTotp, obj)
{ {
FolderId = cipher.FolderId; FolderId = cipher.FolderId;
Favorite = cipher.Favorite; Favorite = cipher.Favorite;
Edit = cipher.Edit; Edit = cipher.Edit;
ViewPassword = cipher.ViewPassword; ViewPassword = cipher.ViewPassword;
Permissions = new CipherPermissionsResponseModel(user, cipher, organizationAbilities);
} }
public Guid? FolderId { get; set; } public Guid? FolderId { get; set; }
public bool Favorite { get; set; } public bool Favorite { get; set; }
public bool Edit { get; set; } public bool Edit { get; set; }
public bool ViewPassword { get; set; } public bool ViewPassword { get; set; }
public CipherPermissionsResponseModel Permissions { get; set; }
} }
public class CipherDetailsResponseModel : CipherResponseModel public class CipherDetailsResponseModel : CipherResponseModel
{ {
public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, public CipherDetailsResponseModel(
CipherDetails cipher,
User user,
IDictionary<Guid, OrganizationAbility> organizationAbilities,
GlobalSettings globalSettings,
IDictionary<Guid, IGrouping<Guid, CollectionCipher>> collectionCiphers, string obj = "cipherDetails") IDictionary<Guid, IGrouping<Guid, CollectionCipher>> collectionCiphers, string obj = "cipherDetails")
: base(cipher, globalSettings, obj) : base(cipher, user, organizationAbilities, globalSettings, obj)
{ {
if (collectionCiphers?.ContainsKey(cipher.Id) ?? false) if (collectionCiphers?.ContainsKey(cipher.Id) ?? false)
{ {
@ -127,15 +139,24 @@ public class CipherDetailsResponseModel : CipherResponseModel
} }
} }
public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, public CipherDetailsResponseModel(
CipherDetails cipher,
User user,
IDictionary<Guid, OrganizationAbility> organizationAbilities,
GlobalSettings globalSettings,
IEnumerable<CollectionCipher> collectionCiphers, string obj = "cipherDetails") IEnumerable<CollectionCipher> collectionCiphers, string obj = "cipherDetails")
: base(cipher, globalSettings, obj) : base(cipher, user, organizationAbilities, globalSettings, obj)
{ {
CollectionIds = collectionCiphers?.Select(c => c.CollectionId) ?? new List<Guid>(); CollectionIds = collectionCiphers?.Select(c => c.CollectionId) ?? new List<Guid>();
} }
public CipherDetailsResponseModel(CipherDetailsWithCollections cipher, GlobalSettings globalSettings, string obj = "cipherDetails") public CipherDetailsResponseModel(
: base(cipher, globalSettings, obj) CipherDetailsWithCollections cipher,
User user,
IDictionary<Guid, OrganizationAbility> organizationAbilities,
GlobalSettings globalSettings,
string obj = "cipherDetails")
: base(cipher, user, organizationAbilities, globalSettings, obj)
{ {
CollectionIds = cipher.CollectionIds ?? new List<Guid>(); CollectionIds = cipher.CollectionIds ?? new List<Guid>();
} }

View File

@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data.Provider;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Models.Api; using Bit.Core.Models.Api;
using Bit.Core.Models.Data; using Bit.Core.Models.Data;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Models.Data.Organizations.OrganizationUsers;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Tools.Entities; using Bit.Core.Tools.Entities;
@ -21,6 +22,7 @@ public class SyncResponseModel : ResponseModel
User user, User user,
bool userTwoFactorEnabled, bool userTwoFactorEnabled,
bool userHasPremiumFromOrganization, bool userHasPremiumFromOrganization,
IDictionary<Guid, OrganizationAbility> organizationAbilities,
IEnumerable<Guid> organizationIdsManagingUser, IEnumerable<Guid> organizationIdsManagingUser,
IEnumerable<OrganizationUserOrganizationDetails> organizationUserDetails, IEnumerable<OrganizationUserOrganizationDetails> organizationUserDetails,
IEnumerable<ProviderUserProviderDetails> providerUserDetails, IEnumerable<ProviderUserProviderDetails> providerUserDetails,
@ -37,7 +39,13 @@ public class SyncResponseModel : ResponseModel
Profile = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, Profile = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails,
providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsManagingUser); providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsManagingUser);
Folders = folders.Select(f => new FolderResponseModel(f)); Folders = folders.Select(f => new FolderResponseModel(f));
Ciphers = ciphers.Select(c => new CipherDetailsResponseModel(c, globalSettings, collectionCiphersDict)); Ciphers = ciphers.Select(cipher =>
new CipherDetailsResponseModel(
cipher,
user,
organizationAbilities,
globalSettings,
collectionCiphersDict));
Collections = collections?.Select( Collections = collections?.Select(
c => new CollectionDetailsResponseModel(c)) ?? new List<CollectionDetailsResponseModel>(); c => new CollectionDetailsResponseModel(c)) ?? new List<CollectionDetailsResponseModel>();
Domains = excludeDomains ? null : new DomainsResponseModel(user, false); Domains = excludeDomains ? null : new DomainsResponseModel(user, false);

View File

@ -0,0 +1,3 @@
namespace Bit.Core.AdminConsole.Errors;
public record Error<T>(string Message, T ErroredValue);

View File

@ -0,0 +1,11 @@
namespace Bit.Core.AdminConsole.Errors;
public record InsufficientPermissionsError<T>(string Message, T ErroredValue) : Error<T>(Message, ErroredValue)
{
public const string Code = "Insufficient Permissions";
public InsufficientPermissionsError(T ErroredValue) : this(Code, ErroredValue)
{
}
}

View File

@ -0,0 +1,11 @@
namespace Bit.Core.AdminConsole.Errors;
public record RecordNotFoundError<T>(string Message, T ErroredValue) : Error<T>(Message, ErroredValue)
{
public const string Code = "Record Not Found";
public RecordNotFoundError(T ErroredValue) : this(Code, ErroredValue)
{
}
}

View File

@ -8,21 +8,25 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
public class PolicyRequirementQuery( public class PolicyRequirementQuery(
IPolicyRepository policyRepository, IPolicyRepository policyRepository,
IEnumerable<RequirementFactory<IPolicyRequirement>> factories) IEnumerable<IPolicyRequirementFactory<IPolicyRequirement>> factories)
: IPolicyRequirementQuery : IPolicyRequirementQuery
{ {
public async Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement public async Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement
{ {
var factory = factories.OfType<RequirementFactory<T>>().SingleOrDefault(); var factory = factories.OfType<IPolicyRequirementFactory<T>>().SingleOrDefault();
if (factory is null) if (factory is null)
{ {
throw new NotImplementedException("No Policy Requirement found for " + typeof(T)); throw new NotImplementedException("No Requirement Factory found for " + typeof(T));
} }
return factory(await GetPolicyDetails(userId)); var policyDetails = await GetPolicyDetails(userId);
var filteredPolicies = policyDetails
.Where(p => p.PolicyType == factory.PolicyType)
.Where(factory.Enforce);
var requirement = factory.Create(filteredPolicies);
return requirement;
} }
private Task<IEnumerable<PolicyDetails>> GetPolicyDetails(Guid userId) => private Task<IEnumerable<PolicyDetails>> GetPolicyDetails(Guid userId)
policyRepository.GetPolicyDetailsByUserId(userId); => policyRepository.GetPolicyDetailsByUserId(userId);
} }

View File

@ -0,0 +1,44 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.Enums;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
/// <summary>
/// A simple base implementation of <see cref="IPolicyRequirementFactory{T}"/> which will be suitable for most policies.
/// It provides sensible defaults to help teams to implement their own Policy Requirements.
/// </summary>
/// <typeparam name="T"></typeparam>
public abstract class BasePolicyRequirementFactory<T> : IPolicyRequirementFactory<T> where T : IPolicyRequirement
{
/// <summary>
/// User roles that are exempt from policy enforcement.
/// Owners and Admins are exempt by default but this may be overridden.
/// </summary>
protected virtual IEnumerable<OrganizationUserType> ExemptRoles { get; } =
[OrganizationUserType.Owner, OrganizationUserType.Admin];
/// <summary>
/// User statuses that are exempt from policy enforcement.
/// Invited and Revoked users are exempt by default, which is appropriate in the majority of cases.
/// </summary>
protected virtual IEnumerable<OrganizationUserStatusType> ExemptStatuses { get; } =
[OrganizationUserStatusType.Invited, OrganizationUserStatusType.Revoked];
/// <summary>
/// Whether a Provider User for the organization is exempt from policy enforcement.
/// Provider Users are exempt by default, which is appropriate in the majority of cases.
/// </summary>
protected virtual bool ExemptProviders { get; } = true;
/// <inheritdoc />
public abstract PolicyType PolicyType { get; }
public bool Enforce(PolicyDetails policyDetails)
=> !policyDetails.HasRole(ExemptRoles) &&
!policyDetails.HasStatus(ExemptStatuses) &&
(!policyDetails.IsProvider || !ExemptProviders);
/// <inheritdoc />
public abstract T Create(IEnumerable<PolicyDetails> policyDetails);
}

View File

@ -0,0 +1,27 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
/// <summary>
/// Policy requirements for the Disable Send policy.
/// </summary>
public class DisableSendPolicyRequirement : IPolicyRequirement
{
/// <summary>
/// Indicates whether Send is disabled for the user. If true, the user should not be able to create or edit Sends.
/// They may still delete existing Sends.
/// </summary>
public bool DisableSend { get; init; }
}
public class DisableSendPolicyRequirementFactory : BasePolicyRequirementFactory<DisableSendPolicyRequirement>
{
public override PolicyType PolicyType => PolicyType.DisableSend;
public override DisableSendPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails)
{
var result = new DisableSendPolicyRequirement { DisableSend = policyDetails.Any() };
return result;
}
}

View File

@ -1,24 +1,11 @@
#nullable enable using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
/// <summary> /// <summary>
/// Represents the business requirements of how one or more enterprise policies will be enforced against a user. /// An object that represents how a <see cref="PolicyType"/> will be enforced against a user.
/// The implementation of this interface will depend on how the policies are enforced in the relevant domain. /// This acts as a bridge between the <see cref="Policy"/> entity saved to the database and the domain that the policy
/// affects. You may represent the impact of the policy in any way that makes sense for the domain.
/// </summary> /// </summary>
public interface IPolicyRequirement; public interface IPolicyRequirement;
/// <summary>
/// A factory function that takes a sequence of <see cref="PolicyDetails"/> and transforms them into a single
/// <see cref="IPolicyRequirement"/> for consumption by the relevant domain. This will receive *all* policy types
/// that may be enforced against a user; when implementing this delegate, you must filter out irrelevant policy types
/// as well as policies that should not be enforced against a user (e.g. due to the user's role or status).
/// </summary>
/// <remarks>
/// See <see cref="PolicyRequirementHelpers"/> for extension methods to handle common requirements when implementing
/// this delegate.
/// </remarks>
public delegate T RequirementFactory<out T>(IEnumerable<PolicyDetails> policyDetails)
where T : IPolicyRequirement;

View File

@ -0,0 +1,39 @@
#nullable enable
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
/// <summary>
/// An interface that defines how to create a single <see cref="IPolicyRequirement"/> from a sequence of
/// <see cref="PolicyDetails"/>.
/// </summary>
/// <typeparam name="T">The <see cref="IPolicyRequirement"/> that the factory produces.</typeparam>
/// <remarks>
/// See <see cref="BasePolicyRequirementFactory{T}"/> for a simple base implementation suitable for most policies.
/// </remarks>
public interface IPolicyRequirementFactory<out T> where T : IPolicyRequirement
{
/// <summary>
/// The <see cref="PolicyType"/> that the requirement relates to.
/// </summary>
PolicyType PolicyType { get; }
/// <summary>
/// A predicate that determines whether a policy should be enforced against the user.
/// </summary>
/// <remarks>Use this to exempt users based on their role, status or other attributes.</remarks>
/// <param name="policyDetails">Policy details for the defined PolicyType.</param>
/// <returns>True if the policy should be enforced against the user, false otherwise.</returns>
bool Enforce(PolicyDetails policyDetails);
/// <summary>
/// A reducer method that creates a single <see cref="IPolicyRequirement"/> from a set of PolicyDetails.
/// </summary>
/// <param name="policyDetails">
/// PolicyDetails for the specified PolicyType, after they have been filtered by the Enforce predicate. That is,
/// this is the final interface to be called.
/// </param>
T Create(IEnumerable<PolicyDetails> policyDetails);
}

View File

@ -1,5 +1,4 @@
using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.Enums; using Bit.Core.Enums;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
@ -7,35 +6,16 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements
public static class PolicyRequirementHelpers public static class PolicyRequirementHelpers
{ {
/// <summary> /// <summary>
/// Filters the PolicyDetails by PolicyType. This is generally required to only get the PolicyDetails that your /// Returns true if the <see cref="PolicyDetails"/> is for one of the specified roles, false otherwise.
/// IPolicyRequirement relates to.
/// </summary> /// </summary>
public static IEnumerable<PolicyDetails> GetPolicyType( public static bool HasRole(
this IEnumerable<PolicyDetails> policyDetails, this PolicyDetails policyDetails,
PolicyType type)
=> policyDetails.Where(x => x.PolicyType == type);
/// <summary>
/// Filters the PolicyDetails to remove the specified user roles. This can be used to exempt
/// owners and admins from policy enforcement.
/// </summary>
public static IEnumerable<PolicyDetails> ExemptRoles(
this IEnumerable<PolicyDetails> policyDetails,
IEnumerable<OrganizationUserType> roles) IEnumerable<OrganizationUserType> roles)
=> policyDetails.Where(x => !roles.Contains(x.OrganizationUserType)); => roles.Contains(policyDetails.OrganizationUserType);
/// <summary> /// <summary>
/// Filters the PolicyDetails to remove organization users who are also provider users for the organization. /// Returns true if the <see cref="PolicyDetails"/> relates to one of the specified statuses, false otherwise.
/// This can be used to exempt provider users from policy enforcement.
/// </summary> /// </summary>
public static IEnumerable<PolicyDetails> ExemptProviders(this IEnumerable<PolicyDetails> policyDetails) public static bool HasStatus(this PolicyDetails policyDetails, IEnumerable<OrganizationUserStatusType> status)
=> policyDetails.Where(x => !x.IsProvider); => status.Contains(policyDetails.OrganizationUserStatus);
/// <summary>
/// Filters the PolicyDetails to remove the specified organization user statuses. For example, this can be used
/// to exempt users in the invited and revoked statuses from policy enforcement.
/// </summary>
public static IEnumerable<PolicyDetails> ExemptStatus(
this IEnumerable<PolicyDetails> policyDetails, IEnumerable<OrganizationUserStatusType> status)
=> policyDetails.Where(x => !status.Contains(x.OrganizationUserStatus));
} }

View File

@ -0,0 +1,46 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.Enums;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
/// <summary>
/// Policy requirements for the Account recovery administration policy.
/// </summary>
public class ResetPasswordPolicyRequirement : IPolicyRequirement
{
/// <summary>
/// List of Organization Ids that require automatic enrollment in password recovery.
/// </summary>
private IEnumerable<Guid> _autoEnrollOrganizations;
public IEnumerable<Guid> AutoEnrollOrganizations { init => _autoEnrollOrganizations = value; }
/// <summary>
/// Returns true if provided organizationId requires automatic enrollment in password recovery.
/// </summary>
public bool AutoEnrollEnabled(Guid organizationId)
{
return _autoEnrollOrganizations.Contains(organizationId);
}
}
public class ResetPasswordPolicyRequirementFactory : BasePolicyRequirementFactory<ResetPasswordPolicyRequirement>
{
public override PolicyType PolicyType => PolicyType.ResetPassword;
protected override bool ExemptProviders => false;
protected override IEnumerable<OrganizationUserType> ExemptRoles => [];
public override ResetPasswordPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails)
{
var result = policyDetails
.Where(p => p.GetDataModel<ResetPasswordDataModel>().AutoEnrollEnabled)
.Select(p => p.OrganizationId)
.ToHashSet();
return new ResetPasswordPolicyRequirement() { AutoEnrollOrganizations = result };
}
}

View File

@ -0,0 +1,34 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
/// <summary>
/// Policy requirements for the Send Options policy.
/// </summary>
public class SendOptionsPolicyRequirement : IPolicyRequirement
{
/// <summary>
/// Indicates whether the user is prohibited from hiding their email from the recipient of a Send.
/// </summary>
public bool DisableHideEmail { get; init; }
}
public class SendOptionsPolicyRequirementFactory : BasePolicyRequirementFactory<SendOptionsPolicyRequirement>
{
public override PolicyType PolicyType => PolicyType.SendOptions;
public override SendOptionsPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails)
{
var result = policyDetails
.Select(p => p.GetDataModel<SendOptionsPolicyData>())
.Aggregate(
new SendOptionsPolicyRequirement(),
(result, data) => new SendOptionsPolicyRequirement
{
DisableHideEmail = result.DisableHideEmail || data.DisableHideEmail
});
return result;
}
}

View File

@ -1,54 +0,0 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.Enums;
namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
/// <summary>
/// Policy requirements for the Disable Send and Send Options policies.
/// </summary>
public class SendPolicyRequirement : IPolicyRequirement
{
/// <summary>
/// Indicates whether Send is disabled for the user. If true, the user should not be able to create or edit Sends.
/// They may still delete existing Sends.
/// </summary>
public bool DisableSend { get; init; }
/// <summary>
/// Indicates whether the user is prohibited from hiding their email from the recipient of a Send.
/// </summary>
public bool DisableHideEmail { get; init; }
/// <summary>
/// Create a new SendPolicyRequirement.
/// </summary>
/// <param name="policyDetails">All PolicyDetails relating to the user.</param>
/// <remarks>
/// This is a <see cref="RequirementFactory{T}"/> for the SendPolicyRequirement.
/// </remarks>
public static SendPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails)
{
var filteredPolicies = policyDetails
.ExemptRoles([OrganizationUserType.Owner, OrganizationUserType.Admin])
.ExemptStatus([OrganizationUserStatusType.Invited, OrganizationUserStatusType.Revoked])
.ExemptProviders()
.ToList();
var result = filteredPolicies
.GetPolicyType(PolicyType.SendOptions)
.Select(p => p.GetDataModel<SendOptionsPolicyData>())
.Aggregate(
new SendPolicyRequirement
{
// Set Disable Send requirement in the initial seed
DisableSend = filteredPolicies.GetPolicyType(PolicyType.DisableSend).Any()
},
(result, data) => new SendPolicyRequirement
{
DisableSend = result.DisableSend,
DisableHideEmail = result.DisableHideEmail || data.DisableHideEmail
});
return result;
}
}

View File

@ -31,32 +31,8 @@ public static class PolicyServiceCollectionExtensions
private static void AddPolicyRequirements(this IServiceCollection services) private static void AddPolicyRequirements(this IServiceCollection services)
{ {
// Register policy requirement factories here services.AddScoped<IPolicyRequirementFactory<IPolicyRequirement>, DisableSendPolicyRequirementFactory>();
services.AddPolicyRequirement(SendPolicyRequirement.Create); services.AddScoped<IPolicyRequirementFactory<IPolicyRequirement>, SendOptionsPolicyRequirementFactory>();
services.AddScoped<IPolicyRequirementFactory<IPolicyRequirement>, ResetPasswordPolicyRequirementFactory>();
} }
/// <summary>
/// Used to register simple policy requirements where its factory method implements CreateRequirement.
/// This MUST be used rather than calling AddScoped directly, because it will ensure the factory method has
/// the correct type to be injected and then identified by <see cref="PolicyRequirementQuery"/> at runtime.
/// </summary>
/// <typeparam name="T">The specific PolicyRequirement being registered.</typeparam>
private static void AddPolicyRequirement<T>(this IServiceCollection serviceCollection, RequirementFactory<T> factory)
where T : class, IPolicyRequirement
=> serviceCollection.AddPolicyRequirement(_ => factory);
/// <summary>
/// Used to register policy requirements where you need to access additional dependencies (usually to return a
/// curried factory method).
/// This MUST be used rather than calling AddScoped directly, because it will ensure the factory method has
/// the correct type to be injected and then identified by <see cref="PolicyRequirementQuery"/> at runtime.
/// </summary>
/// <typeparam name="T">
/// A callback that takes IServiceProvider and returns a RequirementFactory for
/// your policy requirement.
/// </typeparam>
private static void AddPolicyRequirement<T>(this IServiceCollection serviceCollection,
Func<IServiceProvider, RequirementFactory<T>> factory)
where T : class, IPolicyRequirement
=> serviceCollection.AddScoped<RequirementFactory<IPolicyRequirement>>(factory);
} }

View File

@ -6,6 +6,8 @@ using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.Models.Business;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services;
using Bit.Core.Auth.Enums; using Bit.Core.Auth.Enums;
@ -76,6 +78,7 @@ public class OrganizationService : IOrganizationService
private readonly IOrganizationBillingService _organizationBillingService; private readonly IOrganizationBillingService _organizationBillingService;
private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
public OrganizationService( public OrganizationService(
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
@ -111,7 +114,8 @@ public class OrganizationService : IOrganizationService
ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery,
IOrganizationBillingService organizationBillingService, IOrganizationBillingService organizationBillingService,
IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery,
IPricingClient pricingClient) IPricingClient pricingClient,
IPolicyRequirementQuery policyRequirementQuery)
{ {
_organizationRepository = organizationRepository; _organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository; _organizationUserRepository = organizationUserRepository;
@ -147,6 +151,7 @@ public class OrganizationService : IOrganizationService
_organizationBillingService = organizationBillingService; _organizationBillingService = organizationBillingService;
_hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery;
_pricingClient = pricingClient; _pricingClient = pricingClient;
_policyRequirementQuery = policyRequirementQuery;
} }
public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken,
@ -1353,13 +1358,25 @@ public class OrganizationService : IOrganizationService
} }
// Block the user from withdrawal if auto enrollment is enabled // Block the user from withdrawal if auto enrollment is enabled
if (resetPasswordKey == null && resetPasswordPolicy.Data != null) if (_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements))
{ {
var data = JsonSerializer.Deserialize<ResetPasswordDataModel>(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); var resetPasswordPolicyRequirement = await _policyRequirementQuery.GetAsync<ResetPasswordPolicyRequirement>(userId);
if (resetPasswordKey == null && resetPasswordPolicyRequirement.AutoEnrollEnabled(organizationId))
if (data?.AutoEnrollEnabled ?? false)
{ {
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to withdraw from Password Reset."); throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to withdraw from account recovery.");
}
}
else
{
if (resetPasswordKey == null && resetPasswordPolicy.Data != null)
{
var data = JsonSerializer.Deserialize<ResetPasswordDataModel>(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase);
if (data?.AutoEnrollEnabled ?? false)
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to withdraw from account recovery.");
}
} }
} }

View File

@ -0,0 +1,6 @@
namespace Bit.Core.AdminConsole.Shared.Validation;
public interface IValidator<T>
{
public Task<ValidationResult<T>> ValidateAsync(T value);
}

View File

@ -0,0 +1,15 @@
using Bit.Core.AdminConsole.Errors;
namespace Bit.Core.AdminConsole.Shared.Validation;
public abstract record ValidationResult<T>;
public record Valid<T> : ValidationResult<T>
{
public T Value { get; init; }
}
public record Invalid<T> : ValidationResult<T>
{
public IEnumerable<Error<T>> Errors { get; init; }
}

View File

@ -16,6 +16,12 @@ public class AuthRequest : ITableObject<Guid>
public DeviceType RequestDeviceType { get; set; } public DeviceType RequestDeviceType { get; set; }
[MaxLength(50)] [MaxLength(50)]
public string RequestIpAddress { get; set; } public string RequestIpAddress { get; set; }
/// <summary>
/// This country name is populated through a header value fetched from the ISO-3166 country code.
/// It will always be the English short form of the country name. The length should never be over 200 characters.
/// </summary>
[MaxLength(200)]
public string RequestCountryName { get; set; }
public Guid? ResponseDeviceId { get; set; } public Guid? ResponseDeviceId { get; set; }
[MaxLength(25)] [MaxLength(25)]
public string AccessCode { get; set; } public string AccessCode { get; set; }

View File

@ -164,6 +164,7 @@ public class AuthRequestService : IAuthRequestService
RequestDeviceIdentifier = model.DeviceIdentifier, RequestDeviceIdentifier = model.DeviceIdentifier,
RequestDeviceType = _currentContext.DeviceType.Value, RequestDeviceType = _currentContext.DeviceType.Value,
RequestIpAddress = _currentContext.IpAddress, RequestIpAddress = _currentContext.IpAddress,
RequestCountryName = _currentContext.CountryName,
AccessCode = model.AccessCode, AccessCode = model.AccessCode,
PublicKey = model.PublicKey, PublicKey = model.PublicKey,
UserId = user.Id, UserId = user.Id,
@ -176,12 +177,7 @@ public class AuthRequestService : IAuthRequestService
public async Task<AuthRequest> UpdateAuthRequestAsync(Guid authRequestId, Guid currentUserId, AuthRequestUpdateRequestModel model) public async Task<AuthRequest> UpdateAuthRequestAsync(Guid authRequestId, Guid currentUserId, AuthRequestUpdateRequestModel model)
{ {
var authRequest = await _authRequestRepository.GetByIdAsync(authRequestId); var authRequest = await _authRequestRepository.GetByIdAsync(authRequestId) ?? throw new NotFoundException();
if (authRequest == null)
{
throw new NotFoundException();
}
// Once Approval/Disapproval has been set, this AuthRequest should not be updated again. // Once Approval/Disapproval has been set, this AuthRequest should not be updated again.
if (authRequest.Approved is not null) if (authRequest.Approved is not null)

View File

@ -22,4 +22,9 @@ public static class CustomerExtensions
/// <returns></returns> /// <returns></returns>
public static bool HasTaxLocationVerified(this Customer customer) => public static bool HasTaxLocationVerified(this Customer customer) =>
customer?.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; customer?.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported;
public static decimal GetBillingBalance(this Customer customer)
{
return customer != null ? customer.Balance / 100M : default;
}
} }

View File

@ -0,0 +1,26 @@
using Bit.Core.Entities;
namespace Bit.Core.Billing.Extensions;
public static class SubscriberExtensions
{
/// <summary>
/// We are taking only first 30 characters of the SubscriberName because stripe provide for 30 characters for
/// custom_fields,see the link: https://stripe.com/docs/api/invoices/create
/// </summary>
/// <param name="subscriber"></param>
/// <returns></returns>
public static string GetFormattedInvoiceName(this ISubscriber subscriber)
{
var subscriberName = subscriber.SubscriberName();
if (string.IsNullOrWhiteSpace(subscriberName))
{
return string.Empty;
}
return subscriberName.Length <= 30
? subscriberName
: subscriberName[..30];
}
}

View File

@ -95,5 +95,16 @@ public interface IProviderBillingService
Task<Subscription> SetupSubscription( Task<Subscription> SetupSubscription(
Provider provider); Provider provider);
/// <summary>
/// Updates the <paramref name="provider"/>'s payment source and tax information and then sets their subscription's collection_method to be "charge_automatically".
/// </summary>
/// <param name="provider">The <paramref name="provider"/> to update the payment source and tax information for.</param>
/// <param name="tokenizedPaymentSource">The tokenized payment source (ex. Credit Card) to attach to the <paramref name="provider"/>.</param>
/// <param name="taxInformation">The <paramref name="provider"/>'s updated tax information.</param>
Task UpdatePaymentMethod(
Provider provider,
TokenizedPaymentSource tokenizedPaymentSource,
TaxInformation taxInformation);
Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command); Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command);
} }

View File

@ -114,6 +114,20 @@ public static class FeatureFlagKeys
public const string ItemShare = "item-share"; public const string ItemShare = "item-share";
public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application"; public const string RiskInsightsCriticalApplication = "pm-14466-risk-insights-critical-application";
public const string EnableRiskInsightsNotifications = "enable-risk-insights-notifications"; public const string EnableRiskInsightsNotifications = "enable-risk-insights-notifications";
public const string DesktopSendUIRefresh = "desktop-send-ui-refresh";
public const string ExportAttachments = "export-attachments";
/* Vault Team */
public const string PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge";
public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form";
public const string NewDeviceVerificationPermanentDismiss = "new-device-permanent-dismiss";
public const string NewDeviceVerificationTemporaryDismiss = "new-device-temporary-dismiss";
public const string VaultBulkManagementAction = "vault-bulk-management-action";
public const string RestrictProviderAccess = "restrict-provider-access";
public const string SecurityTasks = "security-tasks";
/* Auth Team */
public const string PM9112DeviceApprovalPersistence = "pm-9112-device-approval-persistence";
public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair";
public const string UseTreeWalkerApiForPageDetailsCollection = "use-tree-walker-api-for-page-details-collection"; public const string UseTreeWalkerApiForPageDetailsCollection = "use-tree-walker-api-for-page-details-collection";
@ -121,10 +135,7 @@ public static class FeatureFlagKeys
public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email"; public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email";
public const string EmailVerification = "email-verification"; public const string EmailVerification = "email-verification";
public const string EmailVerificationDisableTimingDelays = "email-verification-disable-timing-delays"; public const string EmailVerificationDisableTimingDelays = "email-verification-disable-timing-delays";
public const string ExtensionRefresh = "extension-refresh";
public const string RestrictProviderAccess = "restrict-provider-access";
public const string PM4154BulkEncryptionService = "PM-4154-bulk-encryption-service"; public const string PM4154BulkEncryptionService = "PM-4154-bulk-encryption-service";
public const string VaultBulkManagementAction = "vault-bulk-management-action";
public const string InlineMenuFieldQualification = "inline-menu-field-qualification"; public const string InlineMenuFieldQualification = "inline-menu-field-qualification";
public const string InlineMenuPositioningImprovements = "inline-menu-positioning-improvements"; public const string InlineMenuPositioningImprovements = "inline-menu-positioning-improvements";
public const string DeviceTrustLogging = "pm-8285-device-trust-logging"; public const string DeviceTrustLogging = "pm-8285-device-trust-logging";
@ -149,13 +160,8 @@ public static class FeatureFlagKeys
public const string RemoveServerVersionHeader = "remove-server-version-header"; public const string RemoveServerVersionHeader = "remove-server-version-header";
public const string GeneratorToolsModernization = "generator-tools-modernization"; public const string GeneratorToolsModernization = "generator-tools-modernization";
public const string NewDeviceVerification = "new-device-verification"; public const string NewDeviceVerification = "new-device-verification";
public const string NewDeviceVerificationTemporaryDismiss = "new-device-temporary-dismiss";
public const string NewDeviceVerificationPermanentDismiss = "new-device-permanent-dismiss";
public const string SecurityTasks = "security-tasks";
public const string MacOsNativeCredentialSync = "macos-native-credential-sync"; public const string MacOsNativeCredentialSync = "macos-native-credential-sync";
public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form";
public const string InlineMenuTotp = "inline-menu-totp"; public const string InlineMenuTotp = "inline-menu-totp";
public const string SelfHostLicenseRefactor = "pm-11516-self-host-license-refactor";
public const string PrivateKeyRegeneration = "pm-12241-private-key-regeneration"; public const string PrivateKeyRegeneration = "pm-12241-private-key-regeneration";
public const string AppReviewPrompt = "app-review-prompt"; public const string AppReviewPrompt = "app-review-prompt";
public const string ResellerManagedOrgAlert = "PM-15814-alert-owners-of-reseller-managed-orgs"; public const string ResellerManagedOrgAlert = "PM-15814-alert-owners-of-reseller-managed-orgs";
@ -173,6 +179,10 @@ public static class FeatureFlagKeys
public const string WebPush = "web-push"; public const string WebPush = "web-push";
public const string AndroidImportLoginsFlow = "import-logins-flow"; public const string AndroidImportLoginsFlow = "import-logins-flow";
public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features";
public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method";
public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias";
public const string SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor";
public const string ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor";
public static List<string> GetAllKeys() public static List<string> GetAllKeys()
{ {

View File

@ -30,6 +30,7 @@ public class CurrentContext : ICurrentContext
public virtual string DeviceIdentifier { get; set; } public virtual string DeviceIdentifier { get; set; }
public virtual DeviceType? DeviceType { get; set; } public virtual DeviceType? DeviceType { get; set; }
public virtual string IpAddress { get; set; } public virtual string IpAddress { get; set; }
public virtual string CountryName { get; set; }
public virtual List<CurrentContextOrganization> Organizations { get; set; } public virtual List<CurrentContextOrganization> Organizations { get; set; }
public virtual List<CurrentContextProvider> Providers { get; set; } public virtual List<CurrentContextProvider> Providers { get; set; }
public virtual Guid? InstallationId { get; set; } public virtual Guid? InstallationId { get; set; }
@ -104,6 +105,12 @@ public class CurrentContext : ICurrentContext
{ {
ClientVersionIsPrerelease = clientVersionIsPrerelease == "1"; ClientVersionIsPrerelease = clientVersionIsPrerelease == "1";
} }
if (httpContext.Request.Headers.TryGetValue("country-name", out var countryName))
{
CountryName = countryName;
}
} }
public async virtual Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings) public async virtual Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings)

View File

@ -20,6 +20,7 @@ public interface ICurrentContext
string DeviceIdentifier { get; set; } string DeviceIdentifier { get; set; }
DeviceType? DeviceType { get; set; } DeviceType? DeviceType { get; set; }
string IpAddress { get; set; } string IpAddress { get; set; }
string CountryName { get; set; }
List<CurrentContextOrganization> Organizations { get; set; } List<CurrentContextOrganization> Organizations { get; set; }
Guid? InstallationId { get; set; } Guid? InstallationId { get; set; }
Guid? OrganizationId { get; set; } Guid? OrganizationId { get; set; }

View File

@ -36,7 +36,7 @@
<PackageReference Include="DnsClient" Version="1.8.0" /> <PackageReference Include="DnsClient" Version="1.8.0" />
<PackageReference Include="Fido2.AspNet" Version="3.0.1" /> <PackageReference Include="Fido2.AspNet" Version="3.0.1" />
<PackageReference Include="Handlebars.Net" Version="2.1.6" /> <PackageReference Include="Handlebars.Net" Version="2.1.6" />
<PackageReference Include="MailKit" Version="4.10.0" /> <PackageReference Include="MailKit" Version="4.11.0" />
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.10" /> <PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="8.0.10" />
<PackageReference Include="Microsoft.Azure.Cosmos" Version="3.46.1" /> <PackageReference Include="Microsoft.Azure.Cosmos" Version="3.46.1" />
<PackageReference Include="Microsoft.Azure.NotificationHubs" Version="4.2.0" /> <PackageReference Include="Microsoft.Azure.NotificationHubs" Version="4.2.0" />

View File

@ -7,7 +7,7 @@
</tr> </tr>
<tr style="margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> <tr style="margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
<td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none;" valign="top" align="left"> <td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none;" valign="top" align="left">
To leave an organization, first log into the <a href="https://vault.bitwarden.com/#/login">web app</a>, select the three dot menu next to the organization name, and select Leave. To leave an organization, first log into the <a href="{{{WebVaultUrl}}}/login">web app</a>, select the three dot menu next to the organization name, and select Leave.
</td> </td>
</tr> </tr>
</table> </table>

View File

@ -1,5 +1,5 @@
{{#>BasicTextLayout}} {{#>BasicTextLayout}}
Your user account has been revoked from the {{OrganizationName}} organization because your account is part of multiple organizations. Before you can rejoin {{OrganizationName}}, you must first leave all other organizations. Your user account has been revoked from the {{OrganizationName}} organization because your account is part of multiple organizations. Before you can rejoin {{OrganizationName}}, you must first leave all other organizations.
To leave an organization, first log in the web app (https://vault.bitwarden.com/#/login), select the three dot menu next to the organization name, and select Leave. To leave an organization, first log in the web app ({{{WebVaultUrl}}}/login), select the three dot menu next to the organization name, and select Leave.
{{/BasicTextLayout}} {{/BasicTextLayout}}

View File

@ -26,7 +26,7 @@
</tr> </tr>
<tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> <tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
<td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: center;" valign="top" align="center"> <td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: center;" valign="top" align="center">
<a href="https://vault.bitwarden.com/#/organizations/{{{OrganizationId}}}/billing/subscription" clicktracking=off target="_blank" style="color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> <a href="{{{VaultSubscriptionUrl}}}" clicktracking=off target="_blank" style="color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
Manage subscription Manage subscription
</a> </a>
<br class="line-break" /> <br class="line-break" />

View File

@ -24,7 +24,7 @@
</tr> </tr>
<tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> <tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
<td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: center;" valign="top" align="center"> <td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: center;" valign="top" align="center">
<a href="https://vault.bitwarden.com/#/organizations/{{{OrganizationId}}}/billing/subscription" clicktracking=off target="_blank" style="color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> <a href="{{{VaultSubscriptionUrl}}}" clicktracking=off target="_blank" style="color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
Manage subscription Manage subscription
</a> </a>
<br class="line-break" /> <br class="line-break" />

View File

@ -24,7 +24,7 @@
</tr> </tr>
<tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> <tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
<td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: center;" valign="top" align="center"> <td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: center;" valign="top" align="center">
<a href="https://vault.bitwarden.com/#/organizations/{{{OrganizationId}}}/billing/subscription" clicktracking=off target="_blank" rel="noopener noreferrer" style="color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> <a href="{{{VaultSubscriptionUrl}}}" clicktracking=off target="_blank" rel="noopener noreferrer" style="color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
Manage subscription Manage subscription
</a> </a>
<br class="line-break" /> <br class="line-break" />

View File

@ -24,7 +24,7 @@
</tr> </tr>
<tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> <tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
<td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: center;" valign="top" align="center"> <td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: center;" valign="top" align="center">
<a href="https://vault.bitwarden.com/#/organizations/{{{OrganizationId}}}/billing/subscription" clicktracking=off target="_blank" rel="noopener noreferrer" style="color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> <a href="{{{VaultSubscriptionUrl}}}" clicktracking=off target="_blank" rel="noopener noreferrer" style="color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
Manage subscription Manage subscription
</a> </a>
<br class="line-break" /> <br class="line-break" />

View File

@ -15,14 +15,21 @@
</tr> </tr>
</table> </table>
<table width="100%" border="0" cellpadding="0" cellspacing="0" <table width="100%" border="0" cellpadding="0" cellspacing="0"
style="display: table; width:100%; padding-bottom: 35px; text-align: center;" align="center"> style="display: table; width:100%; padding-bottom: 24px; text-align: center;" align="center">
<tr> <tr>
<td display="display: table-cell"> <td display="display: table-cell">
<a href="{{ReviewPasswordsUrl}}" clicktracking=off target="_blank" <a href="{{ReviewPasswordsUrl}}" clicktracking=off target="_blank"
style="display: inline-block; color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; border-radius: 999px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;"> style="display: inline-block; font-weight: bold; color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; border-radius: 999px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
Review at-risk passwords Review at-risk passwords
</a> </a>
</td> </td>
</tr> </tr>
<table width="100%" border="0" cellpadding="0" cellspacing="0"
style="display: table; width:100%; padding-bottom: 24px; text-align: center;" align="center">
<tr>
<td display="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; font-style: normal; font-weight: 400; font-size: 12px; line-height: 16px;">
{{formatAdminOwnerEmails AdminOwnerEmails}}
</td>
</tr>
</table> </table>
{{/SecurityTasksHtmlLayout}} {{/SecurityTasksHtmlLayout}}

View File

@ -5,4 +5,13 @@ breach.
Launch the Bitwarden extension to review your at-risk passwords. Launch the Bitwarden extension to review your at-risk passwords.
Review at-risk passwords ({{{ReviewPasswordsUrl}}}) Review at-risk passwords ({{{ReviewPasswordsUrl}}})
{{#if (eq (length AdminOwnerEmails) 1)}}
This request was initiated by {{AdminOwnerEmails.[0]}}.
{{else}}
This request was initiated by
{{#each AdminOwnerEmails}}
{{#if @last}}and {{/if}}{{this}}{{#unless @last}}, {{/unless}}
{{/each}}.
{{/if}}
{{/SecurityTasksHtmlLayout}} {{/SecurityTasksHtmlLayout}}

View File

@ -73,8 +73,11 @@ public class SubscriptionInfo
Name = item.Plan.Nickname; Name = item.Plan.Nickname;
Amount = item.Plan.Amount.GetValueOrDefault() / 100M; Amount = item.Plan.Amount.GetValueOrDefault() / 100M;
Interval = item.Plan.Interval; Interval = item.Plan.Interval;
AddonSubscriptionItem =
Utilities.StaticStore.IsAddonSubscriptionItem(item.Plan.Id); if (item.Metadata != null)
{
AddonSubscriptionItem = item.Metadata.TryGetValue("isAddOn", out var value) && bool.Parse(value);
}
} }
Quantity = (int)item.Quantity; Quantity = (int)item.Quantity;

View File

@ -1,5 +1,7 @@
#nullable enable #nullable enable
using Bit.Core.AdminConsole.Errors;
namespace Bit.Core.Models.Commands; namespace Bit.Core.Models.Commands;
public class CommandResult(IEnumerable<string> errors) public class CommandResult(IEnumerable<string> errors)
@ -9,7 +11,6 @@ public class CommandResult(IEnumerable<string> errors)
public bool Success => ErrorMessages.Count == 0; public bool Success => ErrorMessages.Count == 0;
public bool HasErrors => ErrorMessages.Count > 0; public bool HasErrors => ErrorMessages.Count > 0;
public List<string> ErrorMessages { get; } = errors.ToList(); public List<string> ErrorMessages { get; } = errors.ToList();
public CommandResult() : this(Array.Empty<string>()) { } public CommandResult() : this(Array.Empty<string>()) { }
} }
@ -29,22 +30,30 @@ public class Success : CommandResult
{ {
} }
public abstract class CommandResult<T> public abstract class CommandResult<T>;
{
public class Success<T>(T value) : CommandResult<T>
{
public T Value { get; } = value;
} }
public class Success<T>(T data) : CommandResult<T> public class Failure<T>(IEnumerable<string> errorMessages) : CommandResult<T>
{ {
public T? Data { get; init; } = data; public List<string> ErrorMessages { get; } = errorMessages.ToList();
public string ErrorMessage => string.Join(" ", ErrorMessages);
public Failure(string error) : this([error]) { }
} }
public class Failure<T>(IEnumerable<string> errorMessage) : CommandResult<T> public class Partial<T> : CommandResult<T>
{ {
public IEnumerable<string> ErrorMessages { get; init; } = errorMessage; public T[] Successes { get; set; } = [];
public Error<T>[] Failures { get; set; } = [];
public Failure(string errorMessage) : this(new[] { errorMessage }) public Partial(IEnumerable<T> successfulItems, IEnumerable<Error<T>> failedItems)
{ {
Successes = successfulItems.ToArray();
Failures = failedItems.ToArray();
} }
} }

View File

@ -2,7 +2,7 @@
public class OrganizationSeatsAutoscaledViewModel : BaseMailModel public class OrganizationSeatsAutoscaledViewModel : BaseMailModel
{ {
public Guid OrganizationId { get; set; }
public int InitialSeatCount { get; set; } public int InitialSeatCount { get; set; }
public int CurrentSeatCount { get; set; } public int CurrentSeatCount { get; set; }
public string VaultSubscriptionUrl { get; set; }
} }

View File

@ -2,6 +2,6 @@
public class OrganizationSeatsMaxReachedViewModel : BaseMailModel public class OrganizationSeatsMaxReachedViewModel : BaseMailModel
{ {
public Guid OrganizationId { get; set; }
public int MaxSeatCount { get; set; } public int MaxSeatCount { get; set; }
public string VaultSubscriptionUrl { get; set; }
} }

View File

@ -2,6 +2,6 @@
public class OrganizationServiceAccountsMaxReachedViewModel public class OrganizationServiceAccountsMaxReachedViewModel
{ {
public Guid OrganizationId { get; set; }
public int MaxServiceAccountsCount { get; set; } public int MaxServiceAccountsCount { get; set; }
public string VaultSubscriptionUrl { get; set; }
} }

View File

@ -8,5 +8,7 @@ public class SecurityTaskNotificationViewModel : BaseMailModel
public bool TaskCountPlural => TaskCount != 1; public bool TaskCountPlural => TaskCount != 1;
public IEnumerable<string> AdminOwnerEmails { get; set; }
public string ReviewPasswordsUrl => $"{WebVaultUrl}/browser-extension-prompt"; public string ReviewPasswordsUrl => $"{WebVaultUrl}/browser-extension-prompt";
} }

View File

@ -42,10 +42,7 @@ public class CloudGetOrganizationLicenseQuery : ICloudGetOrganizationLicenseQuer
var subscriptionInfo = await GetSubscriptionAsync(organization); var subscriptionInfo = await GetSubscriptionAsync(organization);
var license = new OrganizationLicense(organization, subscriptionInfo, installationId, _licensingService, version); var license = new OrganizationLicense(organization, subscriptionInfo, installationId, _licensingService, version);
if (_featureService.IsEnabled(FeatureFlagKeys.SelfHostLicenseRefactor)) license.Token = await _licensingService.CreateOrganizationTokenAsync(organization, installationId, subscriptionInfo);
{
license.Token = await _licensingService.CreateOrganizationTokenAsync(organization, installationId, subscriptionInfo);
}
return license; return license;
} }

View File

@ -99,5 +99,5 @@ public interface IMailService
string organizationName); string organizationName);
Task SendClaimedDomainUserEmailAsync(ManagedUserDomainClaimedEmails emailList); Task SendClaimedDomainUserEmailAsync(ManagedUserDomainClaimedEmails emailList);
Task SendDeviceApprovalRequestedNotificationEmailAsync(IEnumerable<string> adminEmails, Guid organizationId, string email, string userName); Task SendDeviceApprovalRequestedNotificationEmailAsync(IEnumerable<string> adminEmails, Guid organizationId, string email, string userName);
Task SendBulkSecurityTaskNotificationsAsync(string orgName, IEnumerable<UserSecurityTasksCount> securityTaskNotificaitons); Task SendBulkSecurityTaskNotificationsAsync(Organization org, IEnumerable<UserSecurityTasksCount> securityTaskNotifications, IEnumerable<string> adminOwnerEmails);
} }

View File

@ -14,18 +14,8 @@ namespace Bit.Core.Services;
public interface IPaymentService public interface IPaymentService
{ {
Task CancelAndRecoverChargesAsync(ISubscriber subscriber); Task CancelAndRecoverChargesAsync(ISubscriber subscriber);
Task<string> PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType,
string paymentToken, Plan plan, short additionalStorageGb, int additionalSeats,
bool premiumAccessAddon, TaxInfo taxInfo, bool provider = false, int additionalSmSeats = 0,
int additionalServiceAccount = 0, bool signupIsFromSecretsManagerTrial = false);
Task<string> PurchaseOrganizationNoPaymentMethod(Organization org, Plan plan, int additionalSeats,
bool premiumAccessAddon, int additionalSmSeats = 0, int additionalServiceAccount = 0,
bool signupIsFromSecretsManagerTrial = false);
Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship); Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship);
Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship); Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship);
Task<string> UpgradeFreeOrganizationAsync(Organization org, Plan plan, OrganizationUpgrade upgrade);
Task<string> PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType, string paymentToken,
short additionalStorageGb, TaxInfo taxInfo);
Task<string> AdjustSubscription( Task<string> AdjustSubscription(
Organization organization, Organization organization,
Plan updatedPlan, Plan updatedPlan,
@ -56,9 +46,7 @@ public interface IPaymentService
Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo); Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo);
Task<string> AddSecretsManagerToSubscription(Organization org, Plan plan, int additionalSmSeats, Task<string> AddSecretsManagerToSubscription(Organization org, Plan plan, int additionalSmSeats,
int additionalServiceAccount); int additionalServiceAccount);
Task<bool> RisksSubscriptionFailure(Organization organization);
Task<bool> HasSecretsManagerStandalone(Organization organization); Task<bool> HasSecretsManagerStandalone(Organization organization);
Task<(DateTime?, DateTime?)> GetSuspensionDateAsync(Stripe.Subscription subscription);
Task<PreviewInvoiceResponseModel> PreviewInvoiceAsync(PreviewIndividualInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); Task<PreviewInvoiceResponseModel> PreviewInvoiceAsync(PreviewIndividualInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId);
Task<PreviewInvoiceResponseModel> PreviewInvoiceAsync(PreviewOrganizationInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); Task<PreviewInvoiceResponseModel> PreviewInvoiceAsync(PreviewOrganizationInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId);

View File

@ -214,9 +214,9 @@ public class HandlebarsMailService : IMailService
var message = CreateDefaultMessage($"{organization.DisplayName()} Seat Count Has Increased", ownerEmails); var message = CreateDefaultMessage($"{organization.DisplayName()} Seat Count Has Increased", ownerEmails);
var model = new OrganizationSeatsAutoscaledViewModel var model = new OrganizationSeatsAutoscaledViewModel
{ {
OrganizationId = organization.Id,
InitialSeatCount = initialSeatCount, InitialSeatCount = initialSeatCount,
CurrentSeatCount = organization.Seats.Value, CurrentSeatCount = organization.Seats.Value,
VaultSubscriptionUrl = GetCloudVaultSubscriptionUrl(organization.Id)
}; };
await AddMessageContentAsync(message, "OrganizationSeatsAutoscaled", model); await AddMessageContentAsync(message, "OrganizationSeatsAutoscaled", model);
@ -229,8 +229,8 @@ public class HandlebarsMailService : IMailService
var message = CreateDefaultMessage($"{organization.DisplayName()} Seat Limit Reached", ownerEmails); var message = CreateDefaultMessage($"{organization.DisplayName()} Seat Limit Reached", ownerEmails);
var model = new OrganizationSeatsMaxReachedViewModel var model = new OrganizationSeatsMaxReachedViewModel
{ {
OrganizationId = organization.Id,
MaxSeatCount = maxSeatCount, MaxSeatCount = maxSeatCount,
VaultSubscriptionUrl = GetCloudVaultSubscriptionUrl(organization.Id)
}; };
await AddMessageContentAsync(message, "OrganizationSeatsMaxReached", model); await AddMessageContentAsync(message, "OrganizationSeatsMaxReached", model);
@ -740,6 +740,45 @@ public class HandlebarsMailService : IMailService
var clickTrackingText = (clickTrackingOff ? "clicktracking=off" : string.Empty); var clickTrackingText = (clickTrackingOff ? "clicktracking=off" : string.Empty);
writer.WriteSafeString($"<a href=\"{href}\" target=\"_blank\" {clickTrackingText}>{text}</a>"); writer.WriteSafeString($"<a href=\"{href}\" target=\"_blank\" {clickTrackingText}>{text}</a>");
}); });
// Construct markup for admin and owner email addresses.
// Using conditionals within the handlebar syntax was including extra spaces around
// concatenated strings, which this helper avoids.
Handlebars.RegisterHelper("formatAdminOwnerEmails", (writer, context, parameters) =>
{
if (parameters.Length == 0)
{
writer.WriteSafeString(string.Empty);
return;
}
var emailList = ((IEnumerable<string>)parameters[0]).ToList();
if (emailList.Count == 0)
{
writer.WriteSafeString(string.Empty);
return;
}
string constructAnchorElement(string email)
{
return $"<a style=\"color: #175DDC\" href=\"mailto:{email}\">{email}</a>";
}
var outputMessage = "This request was initiated by ";
if (emailList.Count == 1)
{
outputMessage += $"{constructAnchorElement(emailList[0])}.";
}
else
{
outputMessage += string.Join(", ", emailList.Take(emailList.Count - 1)
.Select(email => constructAnchorElement(email)));
outputMessage += $", and {constructAnchorElement(emailList.Last())}.";
}
writer.WriteSafeString($"{outputMessage}");
});
} }
public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token)
@ -1103,8 +1142,8 @@ public class HandlebarsMailService : IMailService
var message = CreateDefaultMessage($"{organization.DisplayName()} Secrets Manager Seat Limit Reached", ownerEmails); var message = CreateDefaultMessage($"{organization.DisplayName()} Secrets Manager Seat Limit Reached", ownerEmails);
var model = new OrganizationSeatsMaxReachedViewModel var model = new OrganizationSeatsMaxReachedViewModel
{ {
OrganizationId = organization.Id,
MaxSeatCount = maxSeatCount, MaxSeatCount = maxSeatCount,
VaultSubscriptionUrl = GetCloudVaultSubscriptionUrl(organization.Id)
}; };
await AddMessageContentAsync(message, "OrganizationSmSeatsMaxReached", model); await AddMessageContentAsync(message, "OrganizationSmSeatsMaxReached", model);
@ -1118,8 +1157,8 @@ public class HandlebarsMailService : IMailService
var message = CreateDefaultMessage($"{organization.DisplayName()} Secrets Manager Machine Accounts Limit Reached", ownerEmails); var message = CreateDefaultMessage($"{organization.DisplayName()} Secrets Manager Machine Accounts Limit Reached", ownerEmails);
var model = new OrganizationServiceAccountsMaxReachedViewModel var model = new OrganizationServiceAccountsMaxReachedViewModel
{ {
OrganizationId = organization.Id,
MaxServiceAccountsCount = maxSeatCount, MaxServiceAccountsCount = maxSeatCount,
VaultSubscriptionUrl = GetCloudVaultSubscriptionUrl(organization.Id)
}; };
await AddMessageContentAsync(message, "OrganizationSmServiceAccountsMaxReached", model); await AddMessageContentAsync(message, "OrganizationSmServiceAccountsMaxReached", model);
@ -1201,21 +1240,23 @@ public class HandlebarsMailService : IMailService
await _mailDeliveryService.SendEmailAsync(message); await _mailDeliveryService.SendEmailAsync(message);
} }
public async Task SendBulkSecurityTaskNotificationsAsync(string orgName, IEnumerable<UserSecurityTasksCount> securityTaskNotificaitons) public async Task SendBulkSecurityTaskNotificationsAsync(Organization org, IEnumerable<UserSecurityTasksCount> securityTaskNotifications, IEnumerable<string> adminOwnerEmails)
{ {
MailQueueMessage CreateMessage(UserSecurityTasksCount notification) MailQueueMessage CreateMessage(UserSecurityTasksCount notification)
{ {
var message = CreateDefaultMessage($"{orgName} has identified {notification.TaskCount} at-risk password{(notification.TaskCount.Equals(1) ? "" : "s")}", notification.Email); var sanitizedOrgName = CoreHelpers.SanitizeForEmail(org.DisplayName(), false);
var message = CreateDefaultMessage($"{sanitizedOrgName} has identified {notification.TaskCount} at-risk password{(notification.TaskCount.Equals(1) ? "" : "s")}", notification.Email);
var model = new SecurityTaskNotificationViewModel var model = new SecurityTaskNotificationViewModel
{ {
OrgName = orgName, OrgName = CoreHelpers.SanitizeForEmail(sanitizedOrgName, false),
TaskCount = notification.TaskCount, TaskCount = notification.TaskCount,
AdminOwnerEmails = adminOwnerEmails,
WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash,
}; };
message.Category = "SecurityTasksNotification"; message.Category = "SecurityTasksNotification";
return new MailQueueMessage(message, "SecurityTasksNotification", model); return new MailQueueMessage(message, "SecurityTasksNotification", model);
} }
var messageModels = securityTaskNotificaitons.Select(CreateMessage); var messageModels = securityTaskNotifications.Select(CreateMessage);
await EnqueueMailAsync(messageModels.ToList()); await EnqueueMailAsync(messageModels.ToList());
} }
@ -1223,4 +1264,11 @@ public class HandlebarsMailService : IMailService
{ {
return string.IsNullOrEmpty(userName) ? email : CoreHelpers.SanitizeForEmail(userName, false); return string.IsNullOrEmpty(userName) ? email : CoreHelpers.SanitizeForEmail(userName, false);
} }
private string GetCloudVaultSubscriptionUrl(Guid organizationId)
=> _globalSettings.BaseServiceUri.CloudRegion?.ToLower() switch
{
"eu" => $"https://vault.bitwarden.eu/#/organizations/{organizationId}/billing/subscription",
_ => $"https://vault.bitwarden.com/#/organizations/{organizationId}/billing/subscription"
};
} }

View File

@ -25,9 +25,6 @@ namespace Bit.Core.Services;
public class StripePaymentService : IPaymentService public class StripePaymentService : IPaymentService
{ {
private const string PremiumPlanId = "premium-annually";
private const string StoragePlanId = "storage-gb-annually";
private const string ProviderDiscountId = "msp-discount-35";
private const string SecretsManagerStandaloneDiscountId = "sm-standalone"; private const string SecretsManagerStandaloneDiscountId = "sm-standalone";
private readonly ITransactionRepository _transactionRepository; private readonly ITransactionRepository _transactionRepository;
@ -62,240 +59,6 @@ public class StripePaymentService : IPaymentService
_pricingClient = pricingClient; _pricingClient = pricingClient;
} }
public async Task<string> PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType,
string paymentToken, StaticStore.Plan plan, short additionalStorageGb,
int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo, bool provider = false,
int additionalSmSeats = 0, int additionalServiceAccount = 0, bool signupIsFromSecretsManagerTrial = false)
{
Braintree.Customer braintreeCustomer = null;
string stipeCustomerSourceToken = null;
string stipeCustomerPaymentMethodId = null;
var stripeCustomerMetadata = new Dictionary<string, string>
{
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
};
var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card ||
paymentMethodType == PaymentMethodType.BankAccount;
if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken))
{
if (paymentToken.StartsWith("pm_"))
{
stipeCustomerPaymentMethodId = paymentToken;
}
else
{
stipeCustomerSourceToken = paymentToken;
}
}
else if (paymentMethodType == PaymentMethodType.PayPal)
{
var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false);
var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest
{
PaymentMethodNonce = paymentToken,
Email = org.BillingEmail,
Id = org.BraintreeCustomerIdPrefix() + org.Id.ToString("N").ToLower() + randomSuffix,
CustomFields = new Dictionary<string, string>
{
[org.BraintreeIdField()] = org.Id.ToString(),
[org.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion
}
});
if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0)
{
throw new GatewayException("Failed to create PayPal customer record.");
}
braintreeCustomer = customerResult.Target;
stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id);
}
else
{
throw new GatewayException("Payment method is not supported at this time.");
}
var subCreateOptions = new OrganizationPurchaseSubscriptionOptions(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon
, additionalSmSeats, additionalServiceAccount);
Customer customer = null;
Subscription subscription;
try
{
if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber))
{
taxInfo.TaxIdType = _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry,
taxInfo.TaxIdNumber);
if (taxInfo.TaxIdType == null)
{
_logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.",
taxInfo.BillingAddressCountry,
taxInfo.TaxIdNumber);
throw new BadRequestException("billingTaxIdTypeInferenceError");
}
}
var customerCreateOptions = new CustomerCreateOptions
{
Description = org.DisplayBusinessName(),
Email = org.BillingEmail,
Source = stipeCustomerSourceToken,
PaymentMethod = stipeCustomerPaymentMethodId,
Metadata = stripeCustomerMetadata,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
DefaultPaymentMethod = stipeCustomerPaymentMethodId,
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = org.SubscriberType(),
Value = GetFirstThirtyCharacters(org.SubscriberName()),
}
],
},
Coupon = signupIsFromSecretsManagerTrial
? SecretsManagerStandaloneDiscountId
: provider
? ProviderDiscountId
: null,
Address = new AddressOptions
{
Country = taxInfo?.BillingAddressCountry,
PostalCode = taxInfo?.BillingAddressPostalCode,
// Line1 is required in Stripe's API, suggestion in Docs is to use Business Name instead.
Line1 = taxInfo?.BillingAddressLine1 ?? string.Empty,
Line2 = taxInfo?.BillingAddressLine2,
City = taxInfo?.BillingAddressCity,
State = taxInfo?.BillingAddressState,
},
TaxIdData = !string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)
? [new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber }]
: null
};
customerCreateOptions.AddExpand("tax");
customer = await _stripeAdapter.CustomerCreateAsync(customerCreateOptions);
subCreateOptions.AddExpand("latest_invoice.payment_intent");
subCreateOptions.Customer = customer.Id;
subCreateOptions.EnableAutomaticTax(customer);
subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions);
if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null)
{
if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method")
{
await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new SubscriptionCancelOptions());
throw new GatewayException("Payment method was declined.");
}
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Error creating customer, walking back operation.");
if (customer != null)
{
await _stripeAdapter.CustomerDeleteAsync(customer.Id);
}
if (braintreeCustomer != null)
{
await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id);
}
throw;
}
org.Gateway = GatewayType.Stripe;
org.GatewayCustomerId = customer.Id;
org.GatewaySubscriptionId = subscription.Id;
if (subscription.Status == "incomplete" &&
subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action")
{
org.Enabled = false;
return subscription.LatestInvoice.PaymentIntent.ClientSecret;
}
else
{
org.Enabled = true;
org.ExpirationDate = subscription.CurrentPeriodEnd;
return null;
}
}
public async Task<string> PurchaseOrganizationNoPaymentMethod(Organization org, StaticStore.Plan plan, int additionalSeats, bool premiumAccessAddon,
int additionalSmSeats = 0, int additionalServiceAccount = 0, bool signupIsFromSecretsManagerTrial = false)
{
var stripeCustomerMetadata = new Dictionary<string, string>
{
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
};
var subCreateOptions = new OrganizationPurchaseSubscriptionOptions(org, plan, new TaxInfo(), additionalSeats, 0, premiumAccessAddon
, additionalSmSeats, additionalServiceAccount);
Customer customer = null;
Subscription subscription;
try
{
var customerCreateOptions = new CustomerCreateOptions
{
Description = org.DisplayBusinessName(),
Email = org.BillingEmail,
Metadata = stripeCustomerMetadata,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = org.SubscriberType(),
Value = GetFirstThirtyCharacters(org.SubscriberName()),
}
],
},
Coupon = signupIsFromSecretsManagerTrial
? SecretsManagerStandaloneDiscountId
: null,
TaxIdData = null,
};
customer = await _stripeAdapter.CustomerCreateAsync(customerCreateOptions);
subCreateOptions.AddExpand("latest_invoice.payment_intent");
subCreateOptions.Customer = customer.Id;
subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error creating customer, walking back operation.");
if (customer != null)
{
await _stripeAdapter.CustomerDeleteAsync(customer.Id);
}
throw;
}
org.Gateway = GatewayType.Stripe;
org.GatewayCustomerId = customer.Id;
org.GatewaySubscriptionId = subscription.Id;
if (subscription.Status == "incomplete" &&
subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action")
{
org.Enabled = false;
return subscription.LatestInvoice.PaymentIntent.ClientSecret;
}
org.Enabled = true;
org.ExpirationDate = subscription.CurrentPeriodEnd;
return null;
}
private async Task ChangeOrganizationSponsorship( private async Task ChangeOrganizationSponsorship(
Organization org, Organization org,
OrganizationSponsorship sponsorship, OrganizationSponsorship sponsorship,
@ -324,458 +87,6 @@ public class StripePaymentService : IPaymentService
public Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship) => public Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship) =>
ChangeOrganizationSponsorship(org, sponsorship, false); ChangeOrganizationSponsorship(org, sponsorship, false);
public async Task<string> UpgradeFreeOrganizationAsync(Organization org, StaticStore.Plan plan,
OrganizationUpgrade upgrade)
{
if (!string.IsNullOrWhiteSpace(org.GatewaySubscriptionId))
{
throw new BadRequestException("Organization already has a subscription.");
}
var customerOptions = new CustomerGetOptions();
customerOptions.AddExpand("default_source");
customerOptions.AddExpand("invoice_settings.default_payment_method");
customerOptions.AddExpand("tax");
var customer = await _stripeAdapter.CustomerGetAsync(org.GatewayCustomerId, customerOptions);
if (customer == null)
{
throw new GatewayException("Could not find customer payment profile.");
}
if (!string.IsNullOrEmpty(upgrade.TaxInfo?.BillingAddressCountry) &&
!string.IsNullOrEmpty(upgrade.TaxInfo?.BillingAddressPostalCode))
{
var addressOptions = new AddressOptions
{
Country = upgrade.TaxInfo.BillingAddressCountry,
PostalCode = upgrade.TaxInfo.BillingAddressPostalCode,
// Line1 is required in Stripe's API, suggestion in Docs is to use Business Name instead.
Line1 = upgrade.TaxInfo.BillingAddressLine1 ?? string.Empty,
Line2 = upgrade.TaxInfo.BillingAddressLine2,
City = upgrade.TaxInfo.BillingAddressCity,
State = upgrade.TaxInfo.BillingAddressState,
};
var customerUpdateOptions = new CustomerUpdateOptions { Address = addressOptions };
customerUpdateOptions.AddExpand("default_source");
customerUpdateOptions.AddExpand("invoice_settings.default_payment_method");
customerUpdateOptions.AddExpand("tax");
customer = await _stripeAdapter.CustomerUpdateAsync(org.GatewayCustomerId, customerUpdateOptions);
}
var subCreateOptions = new OrganizationUpgradeSubscriptionOptions(customer.Id, org, plan, upgrade);
subCreateOptions.EnableAutomaticTax(customer);
var (stripePaymentMethod, paymentMethodType) = IdentifyPaymentMethod(customer, subCreateOptions);
var subscription = await ChargeForNewSubscriptionAsync(org, customer, false,
stripePaymentMethod, paymentMethodType, subCreateOptions, null);
org.GatewaySubscriptionId = subscription.Id;
if (subscription.Status == "incomplete" &&
subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action")
{
org.Enabled = false;
return subscription.LatestInvoice.PaymentIntent.ClientSecret;
}
else
{
org.Enabled = true;
org.ExpirationDate = subscription.CurrentPeriodEnd;
return null;
}
}
private (bool stripePaymentMethod, PaymentMethodType PaymentMethodType) IdentifyPaymentMethod(
Customer customer, SubscriptionCreateOptions subCreateOptions)
{
var stripePaymentMethod = false;
var paymentMethodType = PaymentMethodType.Credit;
var hasBtCustomerId = customer.Metadata.ContainsKey("btCustomerId");
if (hasBtCustomerId)
{
paymentMethodType = PaymentMethodType.PayPal;
}
else
{
if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card")
{
paymentMethodType = PaymentMethodType.Card;
stripePaymentMethod = true;
}
else if (customer.DefaultSource != null)
{
if (customer.DefaultSource is Card || customer.DefaultSource is SourceCard)
{
paymentMethodType = PaymentMethodType.Card;
stripePaymentMethod = true;
}
else if (customer.DefaultSource is BankAccount || customer.DefaultSource is SourceAchDebit)
{
paymentMethodType = PaymentMethodType.BankAccount;
stripePaymentMethod = true;
}
}
else
{
var paymentMethod = GetLatestCardPaymentMethod(customer.Id);
if (paymentMethod != null)
{
paymentMethodType = PaymentMethodType.Card;
stripePaymentMethod = true;
subCreateOptions.DefaultPaymentMethod = paymentMethod.Id;
}
}
}
return (stripePaymentMethod, paymentMethodType);
}
public async Task<string> PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType,
string paymentToken, short additionalStorageGb, TaxInfo taxInfo)
{
if (paymentMethodType != PaymentMethodType.Credit && string.IsNullOrWhiteSpace(paymentToken))
{
throw new BadRequestException("Payment token is required.");
}
if (paymentMethodType == PaymentMethodType.Credit &&
(user.Gateway != GatewayType.Stripe || string.IsNullOrWhiteSpace(user.GatewayCustomerId)))
{
throw new BadRequestException("Your account does not have any credit available.");
}
if (paymentMethodType is PaymentMethodType.BankAccount)
{
throw new GatewayException("Payment method is not supported at this time.");
}
var createdStripeCustomer = false;
Customer customer = null;
Braintree.Customer braintreeCustomer = null;
var stripePaymentMethod = paymentMethodType is PaymentMethodType.Card or PaymentMethodType.BankAccount
or PaymentMethodType.Credit;
string stipeCustomerPaymentMethodId = null;
string stipeCustomerSourceToken = null;
if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken))
{
if (paymentToken.StartsWith("pm_"))
{
stipeCustomerPaymentMethodId = paymentToken;
}
else
{
stipeCustomerSourceToken = paymentToken;
}
}
if (user.Gateway == GatewayType.Stripe && !string.IsNullOrWhiteSpace(user.GatewayCustomerId))
{
if (!string.IsNullOrWhiteSpace(paymentToken))
{
await UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, taxInfo);
}
try
{
var customerGetOptions = new CustomerGetOptions();
customerGetOptions.AddExpand("tax");
customer = await _stripeAdapter.CustomerGetAsync(user.GatewayCustomerId, customerGetOptions);
}
catch
{
_logger.LogWarning(
"Attempted to get existing customer from Stripe, but customer ID was not found. Attempting to recreate customer...");
}
}
if (customer == null && !string.IsNullOrWhiteSpace(paymentToken))
{
var stripeCustomerMetadata = new Dictionary<string, string>
{
{ "region", _globalSettings.BaseServiceUri.CloudRegion }
};
if (paymentMethodType == PaymentMethodType.PayPal)
{
var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false);
var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest
{
PaymentMethodNonce = paymentToken,
Email = user.Email,
Id = user.BraintreeCustomerIdPrefix() + user.Id.ToString("N").ToLower() + randomSuffix,
CustomFields = new Dictionary<string, string>
{
[user.BraintreeIdField()] = user.Id.ToString(),
[user.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion
}
});
if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0)
{
throw new GatewayException("Failed to create PayPal customer record.");
}
braintreeCustomer = customerResult.Target;
stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id);
}
else if (!stripePaymentMethod)
{
throw new GatewayException("Payment method is not supported at this time.");
}
var customerCreateOptions = new CustomerCreateOptions
{
Description = user.Name,
Email = user.Email,
Metadata = stripeCustomerMetadata,
PaymentMethod = stipeCustomerPaymentMethodId,
Source = stipeCustomerSourceToken,
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
DefaultPaymentMethod = stipeCustomerPaymentMethodId,
CustomFields =
[
new CustomerInvoiceSettingsCustomFieldOptions()
{
Name = user.SubscriberType(),
Value = GetFirstThirtyCharacters(user.SubscriberName()),
}
]
},
Address = new AddressOptions
{
Line1 = string.Empty,
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode,
},
};
customerCreateOptions.AddExpand("tax");
customer = await _stripeAdapter.CustomerCreateAsync(customerCreateOptions);
createdStripeCustomer = true;
}
if (customer == null)
{
throw new GatewayException("Could not set up customer payment profile.");
}
var subCreateOptions = new SubscriptionCreateOptions
{
Customer = customer.Id,
Items = [],
Metadata = new Dictionary<string, string>
{
[user.GatewayIdField()] = user.Id.ToString()
}
};
subCreateOptions.Items.Add(new SubscriptionItemOptions
{
Plan = PremiumPlanId,
Quantity = 1
});
if (additionalStorageGb > 0)
{
subCreateOptions.Items.Add(new SubscriptionItemOptions
{
Plan = StoragePlanId,
Quantity = additionalStorageGb
});
}
subCreateOptions.EnableAutomaticTax(customer);
var subscription = await ChargeForNewSubscriptionAsync(user, customer, createdStripeCustomer,
stripePaymentMethod, paymentMethodType, subCreateOptions, braintreeCustomer);
user.Gateway = GatewayType.Stripe;
user.GatewayCustomerId = customer.Id;
user.GatewaySubscriptionId = subscription.Id;
if (subscription.Status == "incomplete" &&
subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action")
{
return subscription.LatestInvoice.PaymentIntent.ClientSecret;
}
user.Premium = true;
user.PremiumExpirationDate = subscription.CurrentPeriodEnd;
return null;
}
private async Task<Subscription> ChargeForNewSubscriptionAsync(ISubscriber subscriber, Customer customer,
bool createdStripeCustomer, bool stripePaymentMethod, PaymentMethodType paymentMethodType,
SubscriptionCreateOptions subCreateOptions, Braintree.Customer braintreeCustomer)
{
var addedCreditToStripeCustomer = false;
Braintree.Transaction braintreeTransaction = null;
var subInvoiceMetadata = new Dictionary<string, string>();
Subscription subscription = null;
try
{
if (!stripePaymentMethod)
{
var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new UpcomingInvoiceOptions
{
Customer = customer.Id,
SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items)
});
if (customer.HasTaxLocationVerified())
{
previewInvoice.AutomaticTax = new InvoiceAutomaticTax { Enabled = true };
}
if (previewInvoice.AmountDue > 0)
{
var braintreeCustomerId = customer.Metadata != null &&
customer.Metadata.ContainsKey("btCustomerId") ? customer.Metadata["btCustomerId"] : null;
if (!string.IsNullOrWhiteSpace(braintreeCustomerId))
{
var btInvoiceAmount = (previewInvoice.AmountDue / 100M);
var transactionResult = await _btGateway.Transaction.SaleAsync(
new Braintree.TransactionRequest
{
Amount = btInvoiceAmount,
CustomerId = braintreeCustomerId,
Options = new Braintree.TransactionOptionsRequest
{
SubmitForSettlement = true,
PayPal = new Braintree.TransactionOptionsPayPalRequest
{
CustomField = $"{subscriber.BraintreeIdField()}:{subscriber.Id},{subscriber.BraintreeCloudRegionField()}:{_globalSettings.BaseServiceUri.CloudRegion}"
}
},
CustomFields = new Dictionary<string, string>
{
[subscriber.BraintreeIdField()] = subscriber.Id.ToString(),
[subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion
}
});
if (!transactionResult.IsSuccess())
{
throw new GatewayException("Failed to charge PayPal customer.");
}
braintreeTransaction = transactionResult.Target;
subInvoiceMetadata.Add("btTransactionId", braintreeTransaction.Id);
subInvoiceMetadata.Add("btPayPalTransactionId",
braintreeTransaction.PayPalDetails.AuthorizationId);
}
else
{
throw new GatewayException("No payment was able to be collected.");
}
await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Balance = customer.Balance - previewInvoice.AmountDue
});
addedCreditToStripeCustomer = true;
}
}
else if (paymentMethodType == PaymentMethodType.Credit)
{
var upcomingInvoiceOptions = new UpcomingInvoiceOptions
{
Customer = customer.Id,
SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items),
SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates,
};
upcomingInvoiceOptions.EnableAutomaticTax(customer, null);
var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(upcomingInvoiceOptions);
if (previewInvoice.AmountDue > 0)
{
throw new GatewayException("Your account does not have enough credit available.");
}
}
subCreateOptions.OffSession = true;
subCreateOptions.AddExpand("latest_invoice.payment_intent");
subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions);
if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null)
{
if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method")
{
await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new SubscriptionCancelOptions());
throw new GatewayException("Payment method was declined.");
}
}
if (!stripePaymentMethod && subInvoiceMetadata.Any())
{
var invoices = await _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions
{
Subscription = subscription.Id
});
var invoice = invoices?.FirstOrDefault();
if (invoice == null)
{
throw new GatewayException("Invoice not found.");
}
await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new InvoiceUpdateOptions
{
Metadata = subInvoiceMetadata
});
}
return subscription;
}
catch (Exception e)
{
if (customer != null)
{
if (createdStripeCustomer)
{
await _stripeAdapter.CustomerDeleteAsync(customer.Id);
}
else if (addedCreditToStripeCustomer || customer.Balance < 0)
{
await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions
{
Balance = customer.Balance
});
}
}
if (braintreeTransaction != null)
{
await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id);
}
if (braintreeCustomer != null)
{
await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id);
}
if (e is StripeException strEx &&
(strEx.StripeError?.Message?.Contains("cannot be used because it is not verified") ?? false))
{
throw new GatewayException("Bank account is not yet verified.");
}
throw;
}
}
private List<InvoiceSubscriptionItemOptions> ToInvoiceSubscriptionItemOptions(
List<SubscriptionItemOptions> subItemOptions)
{
return subItemOptions.Select(si => new InvoiceSubscriptionItemOptions
{
Plan = si.Plan,
Price = si.Price,
Quantity = si.Quantity,
Id = si.Id
}).ToList();
}
private async Task<string> FinalizeSubscriptionChangeAsync(ISubscriber subscriber, private async Task<string> FinalizeSubscriptionChangeAsync(ISubscriber subscriber,
SubscriptionUpdate subscriptionUpdate, bool invoiceNow = false) SubscriptionUpdate subscriptionUpdate, bool invoiceNow = false)
{ {
@ -1400,7 +711,7 @@ public class StripePaymentService : IPaymentService
new CustomerInvoiceSettingsCustomFieldOptions() new CustomerInvoiceSettingsCustomFieldOptions()
{ {
Name = subscriber.SubscriberType(), Name = subscriber.SubscriberType(),
Value = GetFirstThirtyCharacters(subscriber.SubscriberName()), Value = subscriber.GetFormattedInvoiceName()
} }
] ]
@ -1492,7 +803,7 @@ public class StripePaymentService : IPaymentService
new CustomerInvoiceSettingsCustomFieldOptions() new CustomerInvoiceSettingsCustomFieldOptions()
{ {
Name = subscriber.SubscriberType(), Name = subscriber.SubscriberType(),
Value = GetFirstThirtyCharacters(subscriber.SubscriberName()) Value = subscriber.GetFormattedInvoiceName()
} }
] ]
}, },
@ -1560,7 +871,7 @@ public class StripePaymentService : IPaymentService
var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions()); var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions());
var billingInfo = new BillingInfo var billingInfo = new BillingInfo
{ {
Balance = GetBillingBalance(customer), Balance = customer.GetBillingBalance(),
PaymentSource = await GetBillingPaymentSourceAsync(customer) PaymentSource = await GetBillingPaymentSourceAsync(customer)
}; };
@ -1768,27 +1079,6 @@ public class StripePaymentService : IPaymentService
new SecretsManagerSubscribeUpdate(org, plan, additionalSmSeats, additionalServiceAccount), new SecretsManagerSubscribeUpdate(org, plan, additionalSmSeats, additionalServiceAccount),
true); true);
public async Task<bool> RisksSubscriptionFailure(Organization organization)
{
var subscriptionInfo = await GetSubscriptionAsync(organization);
if (subscriptionInfo.Subscription is not
{
Status: "active" or "trialing" or "past_due",
CollectionMethod: "charge_automatically"
}
|| subscriptionInfo.UpcomingInvoice == null)
{
return false;
}
var customer = await GetCustomerAsync(organization.GatewayCustomerId, GetCustomerPaymentOptions());
var paymentSource = await GetBillingPaymentSourceAsync(customer);
return paymentSource == null;
}
public async Task<bool> HasSecretsManagerStandalone(Organization organization) public async Task<bool> HasSecretsManagerStandalone(Organization organization)
{ {
if (string.IsNullOrEmpty(organization.GatewayCustomerId)) if (string.IsNullOrEmpty(organization.GatewayCustomerId))
@ -1801,7 +1091,7 @@ public class StripePaymentService : IPaymentService
return customer?.Discount?.Coupon?.Id == SecretsManagerStandaloneDiscountId; return customer?.Discount?.Coupon?.Id == SecretsManagerStandaloneDiscountId;
} }
public async Task<(DateTime?, DateTime?)> GetSuspensionDateAsync(Subscription subscription) private async Task<(DateTime?, DateTime?)> GetSuspensionDateAsync(Subscription subscription)
{ {
if (subscription.Status is not "past_due" && subscription.Status is not "unpaid") if (subscription.Status is not "past_due" && subscription.Status is not "unpaid")
{ {
@ -2117,11 +1407,6 @@ public class StripePaymentService : IPaymentService
return cardPaymentMethods.OrderByDescending(m => m.Created).FirstOrDefault(); return cardPaymentMethods.OrderByDescending(m => m.Created).FirstOrDefault();
} }
private decimal GetBillingBalance(Customer customer)
{
return customer != null ? customer.Balance / 100M : default;
}
private async Task<BillingInfo.BillingSource> GetBillingPaymentSourceAsync(Customer customer) private async Task<BillingInfo.BillingSource> GetBillingPaymentSourceAsync(Customer customer)
{ {
if (customer == null) if (customer == null)
@ -2252,18 +1537,4 @@ public class StripePaymentService : IPaymentService
throw new GatewayException("Failed to retrieve current invoices", exception); throw new GatewayException("Failed to retrieve current invoices", exception);
} }
} }
// We are taking only first 30 characters of the SubscriberName because stripe provide
// for 30 characters for custom_fields,see the link: https://stripe.com/docs/api/invoices/create
private static string GetFirstThirtyCharacters(string subscriberName)
{
if (string.IsNullOrWhiteSpace(subscriberName))
{
return string.Empty;
}
return subscriberName.Length <= 30
? subscriberName
: subscriberName[..30];
}
} }

View File

@ -1218,10 +1218,7 @@ public class UserService : UserManager<User>, IUserService, IDisposable
? new UserLicense(user, _licenseService) ? new UserLicense(user, _licenseService)
: new UserLicense(user, subscriptionInfo, _licenseService); : new UserLicense(user, subscriptionInfo, _licenseService);
if (_featureService.IsEnabled(FeatureFlagKeys.SelfHostLicenseRefactor)) userLicense.Token = await _licenseService.CreateUserTokenAsync(user, subscriptionInfo);
{
userLicense.Token = await _licenseService.CreateUserTokenAsync(user, subscriptionInfo);
}
return userLicense; return userLicense;
} }

View File

@ -324,7 +324,7 @@ public class NoopMailService : IMailService
return Task.FromResult(0); return Task.FromResult(0);
} }
public Task SendBulkSecurityTaskNotificationsAsync(string orgName, IEnumerable<UserSecurityTasksCount> securityTaskNotificaitons) public Task SendBulkSecurityTaskNotificationsAsync(Organization org, IEnumerable<UserSecurityTasksCount> securityTaskNotifications, IEnumerable<string> adminOwnerEmails)
{ {
return Task.FromResult(0); return Task.FromResult(0);
} }

View File

@ -54,12 +54,11 @@ public class ImportCiphersCommand : IImportCiphersCommand
public async Task ImportIntoIndividualVaultAsync( public async Task ImportIntoIndividualVaultAsync(
List<Folder> folders, List<Folder> folders,
List<CipherDetails> ciphers, List<CipherDetails> ciphers,
IEnumerable<KeyValuePair<int, int>> folderRelationships) IEnumerable<KeyValuePair<int, int>> folderRelationships,
Guid importingUserId)
{ {
var userId = folders.FirstOrDefault()?.UserId ?? ciphers.FirstOrDefault()?.UserId;
// Make sure the user can save new ciphers to their personal vault // Make sure the user can save new ciphers to their personal vault
var anyPersonalOwnershipPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(userId.Value, PolicyType.PersonalOwnership); var anyPersonalOwnershipPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(importingUserId, PolicyType.PersonalOwnership);
if (anyPersonalOwnershipPolicies) if (anyPersonalOwnershipPolicies)
{ {
throw new BadRequestException("You cannot import items into your personal vault because you are " + throw new BadRequestException("You cannot import items into your personal vault because you are " +
@ -76,7 +75,7 @@ public class ImportCiphersCommand : IImportCiphersCommand
} }
} }
var userfoldersIds = (await _folderRepository.GetManyByUserIdAsync(userId ?? Guid.Empty)).Select(f => f.Id).ToList(); var userfoldersIds = (await _folderRepository.GetManyByUserIdAsync(importingUserId)).Select(f => f.Id).ToList();
//Assign id to the ones that don't exist in DB //Assign id to the ones that don't exist in DB
//Need to keep the list order to create the relationships //Need to keep the list order to create the relationships
@ -109,10 +108,7 @@ public class ImportCiphersCommand : IImportCiphersCommand
await _cipherRepository.CreateAsync(ciphers, newFolders); await _cipherRepository.CreateAsync(ciphers, newFolders);
// push // push
if (userId.HasValue) await _pushService.PushSyncVaultAsync(importingUserId);
{
await _pushService.PushSyncVaultAsync(userId.Value);
}
} }
public async Task ImportIntoOrganizationalVaultAsync( public async Task ImportIntoOrganizationalVaultAsync(

View File

@ -7,7 +7,7 @@ namespace Bit.Core.Tools.ImportFeatures.Interfaces;
public interface IImportCiphersCommand public interface IImportCiphersCommand
{ {
Task ImportIntoIndividualVaultAsync(List<Folder> folders, List<CipherDetails> ciphers, Task ImportIntoIndividualVaultAsync(List<Folder> folders, List<CipherDetails> ciphers,
IEnumerable<KeyValuePair<int, int>> folderRelationships); IEnumerable<KeyValuePair<int, int>> folderRelationships, Guid importingUserId);
Task ImportIntoOrganizationalVaultAsync(List<Collection> collections, List<CipherDetails> ciphers, Task ImportIntoOrganizationalVaultAsync(List<Collection> collections, List<CipherDetails> ciphers,
IEnumerable<KeyValuePair<int, int>> collectionRelationships, Guid importingUserId); IEnumerable<KeyValuePair<int, int>> collectionRelationships, Guid importingUserId);

View File

@ -326,14 +326,14 @@ public class SendService : ISendService
return; return;
} }
var sendPolicyRequirement = await _policyRequirementQuery.GetAsync<SendPolicyRequirement>(userId.Value); var disableSendRequirement = await _policyRequirementQuery.GetAsync<DisableSendPolicyRequirement>(userId.Value);
if (disableSendRequirement.DisableSend)
if (sendPolicyRequirement.DisableSend)
{ {
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send."); throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
} }
if (sendPolicyRequirement.DisableHideEmail && send.HideEmail.GetValueOrDefault()) var sendOptionsRequirement = await _policyRequirementQuery.GetAsync<SendOptionsPolicyRequirement>(userId.Value);
if (sendOptionsRequirement.DisableHideEmail && send.HideEmail.GetValueOrDefault())
{ {
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send."); throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
} }

View File

@ -158,21 +158,4 @@ public static class StaticStore
public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) => public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) =>
SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType); SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType);
/// <summary>
/// Determines if the stripe plan id is an addon item by checking if the provided stripe plan id
/// matches either the <see cref="Plan.PasswordManagerPlanFeatures.StripeStoragePlanId"/> or <see cref="Plan.SecretsManagerPlanFeatures.StripeServiceAccountPlanId"/>
/// in any <see cref="Plans"/>.
/// </summary>
/// <param name="stripePlanId"></param>
/// <returns>
/// True if the stripePlanId is a addon product, false otherwise
/// </returns>
public static bool IsAddonSubscriptionItem(string stripePlanId)
{
// TODO: PRICING -> https://bitwarden.atlassian.net/browse/PM-16844
return Plans.Any(p =>
p.PasswordManager.StripeStoragePlanId == stripePlanId ||
(p.SecretsManager?.StripeServiceAccountPlanId == stripePlanId));
}
} }

View File

@ -17,19 +17,22 @@ public class CreateManyTaskNotificationsCommand : ICreateManyTaskNotificationsCo
private readonly IMailService _mailService; private readonly IMailService _mailService;
private readonly ICreateNotificationCommand _createNotificationCommand; private readonly ICreateNotificationCommand _createNotificationCommand;
private readonly IPushNotificationService _pushNotificationService; private readonly IPushNotificationService _pushNotificationService;
private readonly IOrganizationUserRepository _organizationUserRepository;
public CreateManyTaskNotificationsCommand( public CreateManyTaskNotificationsCommand(
IGetSecurityTasksNotificationDetailsQuery getSecurityTasksNotificationDetailsQuery, IGetSecurityTasksNotificationDetailsQuery getSecurityTasksNotificationDetailsQuery,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
IMailService mailService, IMailService mailService,
ICreateNotificationCommand createNotificationCommand, ICreateNotificationCommand createNotificationCommand,
IPushNotificationService pushNotificationService) IPushNotificationService pushNotificationService,
IOrganizationUserRepository organizationUserRepository)
{ {
_getSecurityTasksNotificationDetailsQuery = getSecurityTasksNotificationDetailsQuery; _getSecurityTasksNotificationDetailsQuery = getSecurityTasksNotificationDetailsQuery;
_organizationRepository = organizationRepository; _organizationRepository = organizationRepository;
_mailService = mailService; _mailService = mailService;
_createNotificationCommand = createNotificationCommand; _createNotificationCommand = createNotificationCommand;
_pushNotificationService = pushNotificationService; _pushNotificationService = pushNotificationService;
_organizationUserRepository = organizationUserRepository;
} }
public async Task CreateAsync(Guid orgId, IEnumerable<SecurityTask> securityTasks) public async Task CreateAsync(Guid orgId, IEnumerable<SecurityTask> securityTasks)
@ -45,8 +48,11 @@ public class CreateManyTaskNotificationsCommand : ICreateManyTaskNotificationsCo
}).ToList(); }).ToList();
var organization = await _organizationRepository.GetByIdAsync(orgId); var organization = await _organizationRepository.GetByIdAsync(orgId);
var orgAdminEmails = await _organizationUserRepository.GetManyDetailsByRoleAsync(orgId, OrganizationUserType.Admin);
var orgOwnerEmails = await _organizationUserRepository.GetManyDetailsByRoleAsync(orgId, OrganizationUserType.Owner);
var orgAdminAndOwnerEmails = orgAdminEmails.Concat(orgOwnerEmails).Select(x => x.Email).Distinct().ToList();
await _mailService.SendBulkSecurityTaskNotificationsAsync(organization.Name, userTaskCount); await _mailService.SendBulkSecurityTaskNotificationsAsync(organization, userTaskCount, orgAdminAndOwnerEmails);
// Break securityTaskCiphers into separate lists by user Id // Break securityTaskCiphers into separate lists by user Id
var securityTaskCiphersByUser = securityTaskCiphers.GroupBy(x => x.UserId) var securityTaskCiphersByUser = securityTaskCiphers.GroupBy(x => x.UserId)

View File

@ -13,7 +13,9 @@ using Bit.Core.Tools.Models.Business;
using Bit.Core.Tools.Services; using Bit.Core.Tools.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
using Bit.Core.Vault.Enums;
using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Models.Data;
using Bit.Core.Vault.Queries;
using Bit.Core.Vault.Repositories; using Bit.Core.Vault.Repositories;
namespace Bit.Core.Vault.Services; namespace Bit.Core.Vault.Services;
@ -38,6 +40,7 @@ public class CipherService : ICipherService
private const long _fileSizeLeeway = 1024L * 1024L; // 1MB private const long _fileSizeLeeway = 1024L * 1024L; // 1MB
private readonly IReferenceEventService _referenceEventService; private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly IGetCipherPermissionsForUserQuery _getCipherPermissionsForUserQuery;
public CipherService( public CipherService(
ICipherRepository cipherRepository, ICipherRepository cipherRepository,
@ -54,7 +57,8 @@ public class CipherService : ICipherService
IPolicyService policyService, IPolicyService policyService,
GlobalSettings globalSettings, GlobalSettings globalSettings,
IReferenceEventService referenceEventService, IReferenceEventService referenceEventService,
ICurrentContext currentContext) ICurrentContext currentContext,
IGetCipherPermissionsForUserQuery getCipherPermissionsForUserQuery)
{ {
_cipherRepository = cipherRepository; _cipherRepository = cipherRepository;
_folderRepository = folderRepository; _folderRepository = folderRepository;
@ -71,6 +75,7 @@ public class CipherService : ICipherService
_globalSettings = globalSettings; _globalSettings = globalSettings;
_referenceEventService = referenceEventService; _referenceEventService = referenceEventService;
_currentContext = currentContext; _currentContext = currentContext;
_getCipherPermissionsForUserQuery = getCipherPermissionsForUserQuery;
} }
public async Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, public async Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate,
@ -161,6 +166,7 @@ public class CipherService : ICipherService
{ {
ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate);
cipher.RevisionDate = DateTime.UtcNow; cipher.RevisionDate = DateTime.UtcNow;
await ValidateViewPasswordUserAsync(cipher);
await _cipherRepository.ReplaceAsync(cipher); await _cipherRepository.ReplaceAsync(cipher);
await _eventService.LogCipherEventAsync(cipher, Bit.Core.Enums.EventType.Cipher_Updated); await _eventService.LogCipherEventAsync(cipher, Bit.Core.Enums.EventType.Cipher_Updated);
@ -966,4 +972,32 @@ public class CipherService : ICipherService
ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate);
} }
private async Task ValidateViewPasswordUserAsync(Cipher cipher)
{
if (cipher.Type != CipherType.Login || cipher.Data == null || !cipher.OrganizationId.HasValue)
{
return;
}
var existingCipher = await _cipherRepository.GetByIdAsync(cipher.Id);
if (existingCipher == null) return;
var cipherPermissions = await _getCipherPermissionsForUserQuery.GetByOrganization(cipher.OrganizationId.Value);
// Check if user is a "hidden password" user
if (!cipherPermissions.TryGetValue(cipher.Id, out var permission) || !(permission.ViewPassword && permission.Edit))
{
// "hidden password" users may not add cipher key encryption
if (existingCipher.Key == null && cipher.Key != null)
{
throw new BadRequestException("You do not have permission to add cipher key encryption.");
}
// "hidden password" users may not change passwords, TOTP codes, or passkeys, so we need to set them back to the original values
var existingCipherData = JsonSerializer.Deserialize<CipherLoginData>(existingCipher.Data);
var newCipherData = JsonSerializer.Deserialize<CipherLoginData>(cipher.Data);
newCipherData.Fido2Credentials = existingCipherData.Fido2Credentials;
newCipherData.Totp = existingCipherData.Totp;
newCipherData.Password = existingCipherData.Password;
cipher.Data = JsonSerializer.Serialize(newCipherData);
}
}
} }

View File

@ -7,7 +7,7 @@
<PropertyGroup Condition=" '$(RunConfiguration)' == 'Icons' " /> <PropertyGroup Condition=" '$(RunConfiguration)' == 'Icons' " />
<ItemGroup> <ItemGroup>
<PackageReference Include="AngleSharp" Version="1.1.2" /> <PackageReference Include="AngleSharp" Version="1.2.0" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@ -250,6 +250,11 @@ public class DeviceValidator(
var customResponse = new Dictionary<string, object>(); var customResponse = new Dictionary<string, object>();
switch (errorType) switch (errorType)
{ {
/*
* The ErrorMessage is brittle and is used to control the flow in the clients. Do not change them without updating the client as well.
* There is a backwards compatibility issue as well: if you make a change on the clients then ensure that they are backwards
* compatible.
*/
case DeviceValidationResultType.InvalidUser: case DeviceValidationResultType.InvalidUser:
result.ErrorDescription = "Invalid user"; result.ErrorDescription = "Invalid user";
customResponse.Add("ErrorModel", new ErrorResponseModel("invalid user")); customResponse.Add("ErrorModel", new ErrorResponseModel("invalid user"));

View File

@ -1,5 +1,4 @@
using System.Text.Json; using System.Text.Json;
using Bit.Core;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Enums; using Bit.Core.Auth.Enums;
using Bit.Core.Auth.Identity.TokenProviders; using Bit.Core.Auth.Identity.TokenProviders;
@ -155,12 +154,9 @@ public class TwoFactorAuthenticationValidator(
return false; return false;
} }
if (_featureService.IsEnabled(FeatureFlagKeys.RecoveryCodeLogin)) if (type is TwoFactorProviderType.RecoveryCode)
{ {
if (type is TwoFactorProviderType.RecoveryCode) return await _userService.RecoverTwoFactorAsync(user, token);
{
return await _userService.RecoverTwoFactorAsync(user, token);
}
} }
// These cases we want to always return false, U2f is deprecated and OrganizationDuo // These cases we want to always return false, U2f is deprecated and OrganizationDuo

View File

@ -563,8 +563,8 @@ public class OrganizationUserRepository : Repository<OrganizationUser, Guid>, IO
await using var connection = new SqlConnection(ConnectionString); await using var connection = new SqlConnection(ConnectionString);
await connection.ExecuteAsync( await connection.ExecuteAsync(
"[dbo].[OrganizationUser_SetStatusForUsersById]", "[dbo].[OrganizationUser_SetStatusForUsersByGuidIdArray]",
new { OrganizationUserIds = JsonSerializer.Serialize(organizationUserIds), Status = OrganizationUserStatusType.Revoked }, new { OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP(), Status = OrganizationUserStatusType.Revoked },
commandType: CommandType.StoredProcedure); commandType: CommandType.StoredProcedure);
} }

View File

@ -290,21 +290,33 @@ public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Orga
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetByVerifiedUserEmailDomainAsync(Guid userId) public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetByVerifiedUserEmailDomainAsync(Guid userId)
{ {
using (var scope = ServiceScopeFactory.CreateScope()) using var scope = ServiceScopeFactory.CreateScope();
{
var dbContext = GetDatabaseContext(scope);
var query = from u in dbContext.Users var dbContext = GetDatabaseContext(scope);
join ou in dbContext.OrganizationUsers on u.Id equals ou.UserId
join o in dbContext.Organizations on ou.OrganizationId equals o.Id var userQuery = from u in dbContext.Users
join od in dbContext.OrganizationDomains on ou.OrganizationId equals od.OrganizationId
where u.Id == userId where u.Id == userId
&& od.VerifiedDate != null select u;
&& u.Email.ToLower().EndsWith("@" + od.DomainName.ToLower())
select o;
return await query.ToArrayAsync(); var user = await userQuery.FirstOrDefaultAsync();
if (user is null)
{
return new List<Core.AdminConsole.Entities.Organization>();
} }
var userWithDomain = new { UserId = user.Id, EmailDomain = user.Email.Split('@').Last() };
var query = from o in dbContext.Organizations
join ou in dbContext.OrganizationUsers on o.Id equals ou.OrganizationId
join od in dbContext.OrganizationDomains on ou.OrganizationId equals od.OrganizationId
where ou.UserId == userWithDomain.UserId &&
od.DomainName == userWithDomain.EmailDomain &&
od.VerifiedDate != null &&
o.Enabled == true
select o;
return await query.ToArrayAsync();
} }
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetAddableToProviderByUserIdAsync( public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetAddableToProviderByUserIdAsync(

View File

@ -6,6 +6,7 @@
@RequestDeviceIdentifier NVARCHAR(50), @RequestDeviceIdentifier NVARCHAR(50),
@RequestDeviceType TINYINT, @RequestDeviceType TINYINT,
@RequestIpAddress VARCHAR(50), @RequestIpAddress VARCHAR(50),
@RequestCountryName NVARCHAR(200),
@ResponseDeviceId UNIQUEIDENTIFIER, @ResponseDeviceId UNIQUEIDENTIFIER,
@AccessCode VARCHAR(25), @AccessCode VARCHAR(25),
@PublicKey VARCHAR(MAX), @PublicKey VARCHAR(MAX),
@ -20,7 +21,7 @@ BEGIN
SET NOCOUNT ON SET NOCOUNT ON
INSERT INTO [dbo].[AuthRequest] INSERT INTO [dbo].[AuthRequest]
( (
[Id], [Id],
[UserId], [UserId],
[OrganizationId], [OrganizationId],
@ -28,6 +29,7 @@ BEGIN
[RequestDeviceIdentifier], [RequestDeviceIdentifier],
[RequestDeviceType], [RequestDeviceType],
[RequestIpAddress], [RequestIpAddress],
[RequestCountryName],
[ResponseDeviceId], [ResponseDeviceId],
[AccessCode], [AccessCode],
[PublicKey], [PublicKey],
@ -37,24 +39,25 @@ BEGIN
[CreationDate], [CreationDate],
[ResponseDate], [ResponseDate],
[AuthenticationDate] [AuthenticationDate]
) )
VALUES VALUES
( (
@Id, @Id,
@UserId, @UserId,
@OrganizationId, @OrganizationId,
@Type, @Type,
@RequestDeviceIdentifier, @RequestDeviceIdentifier,
@RequestDeviceType, @RequestDeviceType,
@RequestIpAddress, @RequestIpAddress,
@ResponseDeviceId, @RequestCountryName,
@AccessCode, @ResponseDeviceId,
@PublicKey, @AccessCode,
@Key, @PublicKey,
@MasterPasswordHash, @Key,
@Approved, @MasterPasswordHash,
@CreationDate, @Approved,
@ResponseDate, @CreationDate,
@AuthenticationDate @ResponseDate,
@AuthenticationDate
) )
END END

View File

@ -2,10 +2,11 @@
@Id UNIQUEIDENTIFIER OUTPUT, @Id UNIQUEIDENTIFIER OUTPUT,
@UserId UNIQUEIDENTIFIER, @UserId UNIQUEIDENTIFIER,
@OrganizationId UNIQUEIDENTIFIER = NULL, @OrganizationId UNIQUEIDENTIFIER = NULL,
@Type SMALLINT, @Type SMALLINT,
@RequestDeviceIdentifier NVARCHAR(50), @RequestDeviceIdentifier NVARCHAR(50),
@RequestDeviceType SMALLINT, @RequestDeviceType SMALLINT,
@RequestIpAddress VARCHAR(50), @RequestIpAddress VARCHAR(50),
@RequestCountryName NVARCHAR(200),
@ResponseDeviceId UNIQUEIDENTIFIER, @ResponseDeviceId UNIQUEIDENTIFIER,
@AccessCode VARCHAR(25), @AccessCode VARCHAR(25),
@PublicKey VARCHAR(MAX), @PublicKey VARCHAR(MAX),
@ -14,29 +15,30 @@
@Approved BIT, @Approved BIT,
@CreationDate DATETIME2 (7), @CreationDate DATETIME2 (7),
@ResponseDate DATETIME2 (7), @ResponseDate DATETIME2 (7),
@AuthenticationDate DATETIME2 (7) @AuthenticationDate DATETIME2 (7)
AS AS
BEGIN BEGIN
SET NOCOUNT ON SET NOCOUNT ON
UPDATE UPDATE
[dbo].[AuthRequest] [dbo].[AuthRequest]
SET SET
[UserId] = @UserId, [UserId] = @UserId,
[Type] = @Type, [Type] = @Type,
[OrganizationId] = @OrganizationId, [OrganizationId] = @OrganizationId,
[RequestDeviceIdentifier] = @RequestDeviceIdentifier, [RequestDeviceIdentifier] = @RequestDeviceIdentifier,
[RequestDeviceType] = @RequestDeviceType, [RequestDeviceType] = @RequestDeviceType,
[RequestIpAddress] = @RequestIpAddress, [RequestIpAddress] = @RequestIpAddress,
[ResponseDeviceId] = @ResponseDeviceId, [RequestCountryName] = @RequestCountryName,
[AccessCode] = @AccessCode, [ResponseDeviceId] = @ResponseDeviceId,
[PublicKey] = @PublicKey, [AccessCode] = @AccessCode,
[Key] = @Key, [PublicKey] = @PublicKey,
[MasterPasswordHash] = @MasterPasswordHash, [Key] = @Key,
[Approved] = @Approved, [MasterPasswordHash] = @MasterPasswordHash,
[CreationDate] = @CreationDate, [Approved] = @Approved,
[ResponseDate] = @ResponseDate, [CreationDate] = @CreationDate,
[AuthenticationDate] = @AuthenticationDate [ResponseDate] = @ResponseDate,
WHERE [AuthenticationDate] = @AuthenticationDate
[Id] = @Id WHERE
[Id] = @Id
END END

View File

@ -10,6 +10,7 @@ BEGIN
[RequestDeviceIdentifier] = ARI.[RequestDeviceIdentifier], [RequestDeviceIdentifier] = ARI.[RequestDeviceIdentifier],
[RequestDeviceType] = ARI.[RequestDeviceType], [RequestDeviceType] = ARI.[RequestDeviceType],
[RequestIpAddress] = ARI.[RequestIpAddress], [RequestIpAddress] = ARI.[RequestIpAddress],
[RequestCountryName] = ARI.[RequestCountryName],
[ResponseDeviceId] = ARI.[ResponseDeviceId], [ResponseDeviceId] = ARI.[ResponseDeviceId],
[AccessCode] = ARI.[AccessCode], [AccessCode] = ARI.[AccessCode],
[PublicKey] = ARI.[PublicKey], [PublicKey] = ARI.[PublicKey],
@ -22,7 +23,7 @@ BEGIN
[OrganizationId] = ARI.[OrganizationId] [OrganizationId] = ARI.[OrganizationId]
FROM FROM
[dbo].[AuthRequest] AR [dbo].[AuthRequest] AR
INNER JOIN INNER JOIN
OPENJSON(@jsonData) OPENJSON(@jsonData)
WITH ( WITH (
Id UNIQUEIDENTIFIER '$.Id', Id UNIQUEIDENTIFIER '$.Id',
@ -31,6 +32,7 @@ BEGIN
RequestDeviceIdentifier NVARCHAR(50) '$.RequestDeviceIdentifier', RequestDeviceIdentifier NVARCHAR(50) '$.RequestDeviceIdentifier',
RequestDeviceType SMALLINT '$.RequestDeviceType', RequestDeviceType SMALLINT '$.RequestDeviceType',
RequestIpAddress VARCHAR(50) '$.RequestIpAddress', RequestIpAddress VARCHAR(50) '$.RequestIpAddress',
RequestCountryName NVARCHAR(200) '$.RequestCountryName',
ResponseDeviceId UNIQUEIDENTIFIER '$.ResponseDeviceId', ResponseDeviceId UNIQUEIDENTIFIER '$.ResponseDeviceId',
AccessCode VARCHAR(25) '$.AccessCode', AccessCode VARCHAR(25) '$.AccessCode',
PublicKey VARCHAR(MAX) '$.PublicKey', PublicKey VARCHAR(MAX) '$.PublicKey',

View File

@ -15,11 +15,11 @@
[ResponseDate] DATETIME2 (7) NULL, [ResponseDate] DATETIME2 (7) NULL,
[AuthenticationDate] DATETIME2 (7) NULL, [AuthenticationDate] DATETIME2 (7) NULL,
[OrganizationId] UNIQUEIDENTIFIER NULL, [OrganizationId] UNIQUEIDENTIFIER NULL,
[RequestCountryName] NVARCHAR(200) NULL,
CONSTRAINT [PK_AuthRequest] PRIMARY KEY CLUSTERED ([Id] ASC), CONSTRAINT [PK_AuthRequest] PRIMARY KEY CLUSTERED ([Id] ASC),
CONSTRAINT [FK_AuthRequest_User] FOREIGN KEY ([UserId]) REFERENCES [dbo].[User] ([Id]), CONSTRAINT [FK_AuthRequest_User] FOREIGN KEY ([UserId]) REFERENCES [dbo].[User] ([Id]),
CONSTRAINT [FK_AuthRequest_ResponseDevice] FOREIGN KEY ([ResponseDeviceId]) REFERENCES [dbo].[Device] ([Id]), CONSTRAINT [FK_AuthRequest_ResponseDevice] FOREIGN KEY ([ResponseDeviceId]) REFERENCES [dbo].[Device] ([Id]),
CONSTRAINT [FK_AuthRequest_Organization] FOREIGN KEY ([OrganizationId]) REFERENCES [dbo].[Organization] ([Id]) CONSTRAINT [FK_AuthRequest_Organization] FOREIGN KEY ([OrganizationId]) REFERENCES [dbo].[Organization] ([Id])
); );
GO GO

View File

@ -0,0 +1,14 @@
CREATE PROCEDURE [dbo].[OrganizationUser_SetStatusForUsersByGuidIdArray]
@OrganizationUserIds AS [dbo].[GuidIdArray] READONLY,
@Status SMALLINT
AS
BEGIN
SET NOCOUNT ON
UPDATE OU
SET OU.[Status] = @Status
FROM [dbo].[OrganizationUser] OU
INNER JOIN @OrganizationUserIds OUI ON OUI.[Id] = OU.[Id]
EXEC [dbo].[User_BumpAccountRevisionDateByOrganizationUserIds] @OrganizationUserIds
END

View File

@ -4,12 +4,19 @@ AS
BEGIN BEGIN
SET NOCOUNT ON; SET NOCOUNT ON;
WITH CTE_User AS (
SELECT
U.*,
SUBSTRING(U.Email, CHARINDEX('@', U.Email) + 1, LEN(U.Email)) AS EmailDomain
FROM dbo.[UserView] U
WHERE U.[Id] = @UserId
)
SELECT O.* SELECT O.*
FROM [dbo].[UserView] U FROM CTE_User CU
INNER JOIN [dbo].[OrganizationUserView] OU ON U.[Id] = OU.[UserId] INNER JOIN dbo.[OrganizationUserView] OU ON CU.[Id] = OU.[UserId]
INNER JOIN [dbo].[OrganizationView] O ON OU.[OrganizationId] = O.[Id] INNER JOIN dbo.[OrganizationView] O ON OU.[OrganizationId] = O.[Id]
INNER JOIN [dbo].[OrganizationDomainView] OD ON OU.[OrganizationId] = OD.[OrganizationId] INNER JOIN dbo.[OrganizationDomainView] OD ON OU.[OrganizationId] = OD.[OrganizationId]
WHERE U.[Id] = @UserId WHERE OD.[VerifiedDate] IS NOT NULL
AND OD.[VerifiedDate] IS NOT NULL AND CU.EmailDomain = OD.[DomainName]
AND U.[Email] LIKE '%@' + OD.[DomainName]; AND O.[Enabled] = 1
END END

View File

@ -22,3 +22,8 @@ CREATE NONCLUSTERED INDEX [IX_OrganizationDomain_VerifiedDate]
ON [dbo].[OrganizationDomain] ([VerifiedDate]) ON [dbo].[OrganizationDomain] ([VerifiedDate])
INCLUDE ([OrganizationId],[DomainName]); INCLUDE ([OrganizationId],[DomainName]);
GO GO
CREATE NONCLUSTERED INDEX [IX_OrganizationDomain_DomainNameVerifiedDateOrganizationId]
ON [dbo].[OrganizationDomain] ([DomainName],[VerifiedDate])
INCLUDE ([OrganizationId])
GO

View File

@ -7,6 +7,8 @@ using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities; using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Repositories; using Bit.Core.Auth.Repositories;
@ -424,4 +426,93 @@ public class OrganizationUsersControllerTests
.GetManyDetailsByOrganizationAsync(organizationAbility.Id, Arg.Any<bool>(), Arg.Any<bool>()) .GetManyDetailsByOrganizationAsync(organizationAbility.Id, Arg.Any<bool>(), Arg.Any<bool>())
.Returns(organizationUsers); .Returns(organizationUsers);
} }
[Theory]
[BitAutoData]
public async Task Accept_WhenOrganizationUsePoliciesIsEnabledAndResetPolicyIsEnabled_WithPolicyRequirementsEnabled_ShouldHandleResetPassword(Guid orgId, Guid orgUserId,
OrganizationUserAcceptRequestModel model, User user, SutProvider<OrganizationUsersController> sutProvider)
{
// Arrange
var applicationCacheService = sutProvider.GetDependency<IApplicationCacheService>();
applicationCacheService.GetOrganizationAbilityAsync(orgId).Returns(new OrganizationAbility { UsePolicies = true });
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var policy = new Policy
{
Enabled = true,
Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }),
};
var userService = sutProvider.GetDependency<IUserService>();
userService.GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user);
var policyRequirementQuery = sutProvider.GetDependency<IPolicyRequirementQuery>();
var policyRepository = sutProvider.GetDependency<IPolicyRepository>();
var policyRequirement = new ResetPasswordPolicyRequirement { AutoEnrollOrganizations = [orgId] };
policyRequirementQuery.GetAsync<ResetPasswordPolicyRequirement>(user.Id).Returns(policyRequirement);
// Act
await sutProvider.Sut.Accept(orgId, orgUserId, model);
// Assert
await sutProvider.GetDependency<IAcceptOrgUserCommand>().Received(1)
.AcceptOrgUserByEmailTokenAsync(orgUserId, user, model.Token, userService);
await sutProvider.GetDependency<IOrganizationService>().Received(1)
.UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id);
await userService.Received(1).GetUserByPrincipalAsync(default);
await applicationCacheService.Received(0).GetOrganizationAbilityAsync(orgId);
await policyRepository.Received(0).GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword);
await policyRequirementQuery.Received(1).GetAsync<ResetPasswordPolicyRequirement>(user.Id);
Assert.True(policyRequirement.AutoEnrollEnabled(orgId));
}
[Theory]
[BitAutoData]
public async Task Accept_WithInvalidModelResetPasswordKey_WithPolicyRequirementsEnabled_ThrowsBadRequestException(Guid orgId, Guid orgUserId,
OrganizationUserAcceptRequestModel model, User user, SutProvider<OrganizationUsersController> sutProvider)
{
// Arrange
model.ResetPasswordKey = " ";
var applicationCacheService = sutProvider.GetDependency<IApplicationCacheService>();
applicationCacheService.GetOrganizationAbilityAsync(orgId).Returns(new OrganizationAbility { UsePolicies = true });
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
var policy = new Policy
{
Enabled = true,
Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }),
};
var userService = sutProvider.GetDependency<IUserService>();
userService.GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user);
var policyRepository = sutProvider.GetDependency<IPolicyRepository>();
var policyRequirementQuery = sutProvider.GetDependency<IPolicyRequirementQuery>();
var policyRequirement = new ResetPasswordPolicyRequirement { AutoEnrollOrganizations = [orgId] };
policyRequirementQuery.GetAsync<ResetPasswordPolicyRequirement>(user.Id).Returns(policyRequirement);
// Act
var exception = await Assert.ThrowsAsync<BadRequestException>(() =>
sutProvider.Sut.Accept(orgId, orgUserId, model));
// Assert
await sutProvider.GetDependency<IAcceptOrgUserCommand>().Received(0)
.AcceptOrgUserByEmailTokenAsync(orgUserId, user, model.Token, userService);
await sutProvider.GetDependency<IOrganizationService>().Received(0)
.UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id);
await userService.Received(1).GetUserByPrincipalAsync(default);
await applicationCacheService.Received(0).GetOrganizationAbilityAsync(orgId);
await policyRepository.Received(0).GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword);
await policyRequirementQuery.Received(1).GetAsync<ResetPasswordPolicyRequirement>(user.Id);
Assert.Equal("Master Password reset is required, but not provided.", exception.Message);
}
} }

View File

@ -4,12 +4,15 @@ using Bit.Api.AdminConsole.Controllers;
using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Auth.Models.Request.Accounts;
using Bit.Core; using Bit.Core;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Models.Business.Tokenables;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationApiKeys.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationApiKeys.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Auth.Entities; using Bit.Core.Auth.Entities;
using Bit.Core.Auth.Enums; using Bit.Core.Auth.Enums;
@ -55,6 +58,7 @@ public class OrganizationsControllerTests : IDisposable
private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand;
private readonly ICloudOrganizationSignUpCommand _cloudOrganizationSignUpCommand; private readonly ICloudOrganizationSignUpCommand _cloudOrganizationSignUpCommand;
private readonly IOrganizationDeleteCommand _organizationDeleteCommand; private readonly IOrganizationDeleteCommand _organizationDeleteCommand;
private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly OrganizationsController _sut; private readonly OrganizationsController _sut;
@ -80,6 +84,7 @@ public class OrganizationsControllerTests : IDisposable
_removeOrganizationUserCommand = Substitute.For<IRemoveOrganizationUserCommand>(); _removeOrganizationUserCommand = Substitute.For<IRemoveOrganizationUserCommand>();
_cloudOrganizationSignUpCommand = Substitute.For<ICloudOrganizationSignUpCommand>(); _cloudOrganizationSignUpCommand = Substitute.For<ICloudOrganizationSignUpCommand>();
_organizationDeleteCommand = Substitute.For<IOrganizationDeleteCommand>(); _organizationDeleteCommand = Substitute.For<IOrganizationDeleteCommand>();
_policyRequirementQuery = Substitute.For<IPolicyRequirementQuery>();
_pricingClient = Substitute.For<IPricingClient>(); _pricingClient = Substitute.For<IPricingClient>();
_sut = new OrganizationsController( _sut = new OrganizationsController(
@ -103,6 +108,7 @@ public class OrganizationsControllerTests : IDisposable
_removeOrganizationUserCommand, _removeOrganizationUserCommand,
_cloudOrganizationSignUpCommand, _cloudOrganizationSignUpCommand,
_organizationDeleteCommand, _organizationDeleteCommand,
_policyRequirementQuery,
_pricingClient); _pricingClient);
} }
@ -236,4 +242,55 @@ public class OrganizationsControllerTests : IDisposable
await _organizationDeleteCommand.Received(1).DeleteAsync(organization); await _organizationDeleteCommand.Received(1).DeleteAsync(organization);
} }
[Theory, AutoData]
public async Task GetAutoEnrollStatus_WithPolicyRequirementsEnabled_ReturnsOrganizationAutoEnrollStatus_WithResetPasswordEnabledTrue(
User user,
Organization organization,
OrganizationUser organizationUser
)
{
var policyRequirement = new ResetPasswordPolicyRequirement() { AutoEnrollOrganizations = [organization.Id] };
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_organizationRepository.GetByIdentifierAsync(organization.Id.ToString()).Returns(organization);
_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true);
_organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser);
_policyRequirementQuery.GetAsync<ResetPasswordPolicyRequirement>(user.Id).Returns(policyRequirement);
var result = await _sut.GetAutoEnrollStatus(organization.Id.ToString());
await _userService.Received(1).GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>());
await _organizationRepository.Received(1).GetByIdentifierAsync(organization.Id.ToString());
await _policyRequirementQuery.Received(1).GetAsync<ResetPasswordPolicyRequirement>(user.Id);
Assert.True(result.ResetPasswordEnabled);
Assert.Equal(result.Id, organization.Id);
}
[Theory, AutoData]
public async Task GetAutoEnrollStatus_WithPolicyRequirementsDisabled_ReturnsOrganizationAutoEnrollStatus_WithResetPasswordEnabledTrue(
User user,
Organization organization,
OrganizationUser organizationUser
)
{
var policy = new Policy() { Type = PolicyType.ResetPassword, Enabled = true, Data = "{\"AutoEnrollEnabled\": true}", OrganizationId = organization.Id };
_userService.GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>()).Returns(user);
_organizationRepository.GetByIdentifierAsync(organization.Id.ToString()).Returns(organization);
_featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false);
_organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser);
_policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword).Returns(policy);
var result = await _sut.GetAutoEnrollStatus(organization.Id.ToString());
await _userService.Received(1).GetUserByPrincipalAsync(Arg.Any<ClaimsPrincipal>());
await _organizationRepository.Received(1).GetByIdentifierAsync(organization.Id.ToString());
await _policyRequirementQuery.Received(0).GetAsync<ResetPasswordPolicyRequirement>(user.Id);
await _policyRepository.Received(1).GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword);
Assert.True(result.ResetPasswordEnabled);
}
} }

View File

@ -15,16 +15,12 @@ using Bit.Core.Auth.Models.Api.Request.Accounts;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.UserFeatures.TdeOffboardingPassword.Interfaces; using Bit.Core.Auth.UserFeatures.TdeOffboardingPassword.Interfaces;
using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.KeyManagement.UserKey; using Bit.Core.KeyManagement.UserKey;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities; using Bit.Core.Tools.Entities;
using Bit.Core.Tools.Services;
using Bit.Core.Vault.Entities; using Bit.Core.Vault.Entities;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Identity;
@ -37,10 +33,8 @@ public class AccountsControllerTests : IDisposable
{ {
private readonly AccountsController _sut; private readonly AccountsController _sut;
private readonly GlobalSettings _globalSettings;
private readonly IOrganizationService _organizationService; private readonly IOrganizationService _organizationService;
private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IPaymentService _paymentService;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly IProviderUserRepository _providerUserRepository; private readonly IProviderUserRepository _providerUserRepository;
private readonly IPolicyService _policyService; private readonly IPolicyService _policyService;
@ -48,9 +42,6 @@ public class AccountsControllerTests : IDisposable
private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IRotateUserKeyCommand _rotateUserKeyCommand;
private readonly ITdeOffboardingPasswordCommand _tdeOffboardingPasswordCommand; private readonly ITdeOffboardingPasswordCommand _tdeOffboardingPasswordCommand;
private readonly IFeatureService _featureService; private readonly IFeatureService _featureService;
private readonly ISubscriberService _subscriberService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
private readonly IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>> _cipherValidator; private readonly IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>> _cipherValidator;
private readonly IRotationValidator<IEnumerable<FolderWithIdRequestModel>, IEnumerable<Folder>> _folderValidator; private readonly IRotationValidator<IEnumerable<FolderWithIdRequestModel>, IEnumerable<Folder>> _folderValidator;
@ -70,16 +61,11 @@ public class AccountsControllerTests : IDisposable
_organizationService = Substitute.For<IOrganizationService>(); _organizationService = Substitute.For<IOrganizationService>();
_organizationUserRepository = Substitute.For<IOrganizationUserRepository>(); _organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
_providerUserRepository = Substitute.For<IProviderUserRepository>(); _providerUserRepository = Substitute.For<IProviderUserRepository>();
_paymentService = Substitute.For<IPaymentService>();
_globalSettings = new GlobalSettings();
_policyService = Substitute.For<IPolicyService>(); _policyService = Substitute.For<IPolicyService>();
_setInitialMasterPasswordCommand = Substitute.For<ISetInitialMasterPasswordCommand>(); _setInitialMasterPasswordCommand = Substitute.For<ISetInitialMasterPasswordCommand>();
_rotateUserKeyCommand = Substitute.For<IRotateUserKeyCommand>(); _rotateUserKeyCommand = Substitute.For<IRotateUserKeyCommand>();
_tdeOffboardingPasswordCommand = Substitute.For<ITdeOffboardingPasswordCommand>(); _tdeOffboardingPasswordCommand = Substitute.For<ITdeOffboardingPasswordCommand>();
_featureService = Substitute.For<IFeatureService>(); _featureService = Substitute.For<IFeatureService>();
_subscriberService = Substitute.For<ISubscriberService>();
_referenceEventService = Substitute.For<IReferenceEventService>();
_currentContext = Substitute.For<ICurrentContext>();
_cipherValidator = _cipherValidator =
Substitute.For<IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>>>(); Substitute.For<IRotationValidator<IEnumerable<CipherWithIdRequestModel>, IEnumerable<Cipher>>>();
_folderValidator = _folderValidator =
@ -93,20 +79,15 @@ public class AccountsControllerTests : IDisposable
IReadOnlyList<OrganizationUser>>>(); IReadOnlyList<OrganizationUser>>>();
_sut = new AccountsController( _sut = new AccountsController(
_globalSettings,
_organizationService, _organizationService,
_organizationUserRepository, _organizationUserRepository,
_providerUserRepository, _providerUserRepository,
_paymentService,
_userService, _userService,
_policyService, _policyService,
_setInitialMasterPasswordCommand, _setInitialMasterPasswordCommand,
_tdeOffboardingPasswordCommand, _tdeOffboardingPasswordCommand,
_rotateUserKeyCommand, _rotateUserKeyCommand,
_featureService, _featureService,
_subscriberService,
_referenceEventService,
_currentContext,
_cipherValidator, _cipherValidator,
_folderValidator, _folderValidator,
_sendValidator, _sendValidator,

View File

@ -243,20 +243,22 @@ public class PushControllerTests
PushToken = "test-push-token", PushToken = "test-push-token",
UserId = userId.ToString(), UserId = userId.ToString(),
Type = DeviceType.Android, Type = DeviceType.Android,
Identifier = identifier.ToString() Identifier = identifier.ToString(),
})); }));
Assert.Equal("Not correctly configured for push relays.", exception.Message); Assert.Equal("Not correctly configured for push relays.", exception.Message);
await sutProvider.GetDependency<IPushRegistrationService>().Received(0) await sutProvider.GetDependency<IPushRegistrationService>().Received(0)
.CreateOrUpdateRegistrationAsync(Arg.Any<PushRegistrationData>(), Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string>(), .CreateOrUpdateRegistrationAsync(Arg.Any<PushRegistrationData>(), Arg.Any<string>(), Arg.Any<string>(),
Arg.Any<DeviceType>(), Arg.Any<IEnumerable<string>>(), Arg.Any<Guid>()); Arg.Any<string>(), Arg.Any<DeviceType>(), Arg.Any<IEnumerable<string>>(), Arg.Any<Guid>());
} }
[Theory] [Theory]
[BitAutoData] [BitAutoData(false)]
public async Task? RegisterAsync_ValidModel_CreatedOrUpdatedRegistration(SutProvider<PushController> sutProvider, [BitAutoData(true)]
Guid installationId, Guid userId, Guid identifier, Guid deviceId, Guid organizationId) public async Task RegisterAsync_ValidModel_CreatedOrUpdatedRegistration(bool haveOrganizationId,
SutProvider<PushController> sutProvider, Guid installationId, Guid userId, Guid identifier, Guid deviceId,
Guid organizationId)
{ {
sutProvider.GetDependency<IGlobalSettings>().SelfHosted = false; sutProvider.GetDependency<IGlobalSettings>().SelfHosted = false;
sutProvider.GetDependency<ICurrentContext>().InstallationId.Returns(installationId); sutProvider.GetDependency<ICurrentContext>().InstallationId.Returns(installationId);
@ -273,19 +275,29 @@ public class PushControllerTests
UserId = userId.ToString(), UserId = userId.ToString(),
Type = DeviceType.Android, Type = DeviceType.Android,
Identifier = identifier.ToString(), Identifier = identifier.ToString(),
OrganizationIds = [organizationId.ToString()], OrganizationIds = haveOrganizationId ? [organizationId.ToString()] : null,
InstallationId = installationId InstallationId = installationId
}; };
await sutProvider.Sut.RegisterAsync(model); await sutProvider.Sut.RegisterAsync(model);
await sutProvider.GetDependency<IPushRegistrationService>().Received(1) await sutProvider.GetDependency<IPushRegistrationService>().Received(1)
.CreateOrUpdateRegistrationAsync(Arg.Is<PushRegistrationData>(data => data == new PushRegistrationData(model.PushToken)), expectedDeviceId, expectedUserId, .CreateOrUpdateRegistrationAsync(
Arg.Is<PushRegistrationData>(data => data == new PushRegistrationData(model.PushToken)),
expectedDeviceId, expectedUserId,
expectedIdentifier, DeviceType.Android, Arg.Do<IEnumerable<string>>(organizationIds => expectedIdentifier, DeviceType.Android, Arg.Do<IEnumerable<string>>(organizationIds =>
{ {
Assert.NotNull(organizationIds);
var organizationIdsList = organizationIds.ToList(); var organizationIdsList = organizationIds.ToList();
Assert.Contains(expectedOrganizationId, organizationIdsList); if (haveOrganizationId)
Assert.Single(organizationIdsList); {
Assert.Contains(expectedOrganizationId, organizationIdsList);
Assert.Single(organizationIdsList);
}
else
{
Assert.Empty(organizationIdsList);
}
}), installationId); }), installationId);
} }
} }

View File

@ -79,7 +79,8 @@ public class ImportCiphersControllerTests
.ImportIntoIndividualVaultAsync( .ImportIntoIndividualVaultAsync(
Arg.Any<List<Folder>>(), Arg.Any<List<Folder>>(),
Arg.Any<List<CipherDetails>>(), Arg.Any<List<CipherDetails>>(),
Arg.Any<IEnumerable<KeyValuePair<int, int>>>() Arg.Any<IEnumerable<KeyValuePair<int, int>>>(),
user.Id
); );
} }

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,23 @@
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies;
/// <summary>
/// Intentionally simplified PolicyRequirement that just holds the input PolicyDetails for us to assert against.
/// </summary>
public class TestPolicyRequirement : IPolicyRequirement
{
public IEnumerable<PolicyDetails> Policies { get; init; } = [];
}
public class TestPolicyRequirementFactory(Func<PolicyDetails, bool> enforce) : IPolicyRequirementFactory<TestPolicyRequirement>
{
public PolicyType PolicyType => PolicyType.SingleOrg;
public bool Enforce(PolicyDetails policyDetails) => enforce(policyDetails);
public TestPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails)
=> new() { Policies = policyDetails };
}

View File

@ -1,6 +1,6 @@
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Repositories;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute; using NSubstitute;
@ -11,50 +11,72 @@ namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies;
[SutProviderCustomize] [SutProviderCustomize]
public class PolicyRequirementQueryTests public class PolicyRequirementQueryTests
{ {
/// <summary>
/// Tests that the query correctly registers, retrieves and instantiates arbitrary IPolicyRequirements
/// according to their provided CreateRequirement delegate.
/// </summary>
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetAsync_Works(Guid userId, Guid organizationId) public async Task GetAsync_IgnoresOtherPolicyTypes(Guid userId)
{ {
var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
var otherPolicy = new PolicyDetails { PolicyType = PolicyType.RequireSso };
var policyRepository = Substitute.For<IPolicyRepository>(); var policyRepository = Substitute.For<IPolicyRepository>();
var factories = new List<RequirementFactory<IPolicyRequirement>> policyRepository.GetPolicyDetailsByUserId(userId).Returns([otherPolicy, thisPolicy]);
{
// In prod this cast is handled when the CreateRequirement delegate is registered in DI
(RequirementFactory<TestPolicyRequirement>)TestPolicyRequirement.Create
};
var sut = new PolicyRequirementQuery(policyRepository, factories); var factory = new TestPolicyRequirementFactory(_ => true);
policyRepository.GetPolicyDetailsByUserId(userId).Returns([ var sut = new PolicyRequirementQuery(policyRepository, [factory]);
new PolicyDetails
{
OrganizationId = organizationId
}
]);
var requirement = await sut.GetAsync<TestPolicyRequirement>(userId); var requirement = await sut.GetAsync<TestPolicyRequirement>(userId);
Assert.Equal(organizationId, requirement.OrganizationId);
Assert.Contains(thisPolicy, requirement.Policies);
Assert.DoesNotContain(otherPolicy, requirement.Policies);
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task GetAsync_ThrowsIfNoRequirementRegistered(Guid userId) public async Task GetAsync_CallsEnforceCallback(Guid userId)
{
// Arrange policies
var policyRepository = Substitute.For<IPolicyRepository>();
var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
var otherPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
policyRepository.GetPolicyDetailsByUserId(userId).Returns([thisPolicy, otherPolicy]);
// Arrange a substitute Enforce function so that we can inspect the received calls
var callback = Substitute.For<Func<PolicyDetails, bool>>();
callback(Arg.Any<PolicyDetails>()).Returns(x => x.Arg<PolicyDetails>() == thisPolicy);
// Arrange the sut
var factory = new TestPolicyRequirementFactory(callback);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
// Act
var requirement = await sut.GetAsync<TestPolicyRequirement>(userId);
// Assert
Assert.Contains(thisPolicy, requirement.Policies);
Assert.DoesNotContain(otherPolicy, requirement.Policies);
callback.Received()(Arg.Is(thisPolicy));
callback.Received()(Arg.Is(otherPolicy));
}
[Theory, BitAutoData]
public async Task GetAsync_ThrowsIfNoFactoryRegistered(Guid userId)
{ {
var policyRepository = Substitute.For<IPolicyRepository>(); var policyRepository = Substitute.For<IPolicyRepository>();
var sut = new PolicyRequirementQuery(policyRepository, []); var sut = new PolicyRequirementQuery(policyRepository, []);
var exception = await Assert.ThrowsAsync<NotImplementedException>(() var exception = await Assert.ThrowsAsync<NotImplementedException>(()
=> sut.GetAsync<TestPolicyRequirement>(userId)); => sut.GetAsync<TestPolicyRequirement>(userId));
Assert.Contains("No Policy Requirement found", exception.Message); Assert.Contains("No Requirement Factory found", exception.Message);
} }
/// <summary> [Theory, BitAutoData]
/// Intentionally simplified PolicyRequirement that just holds the Policy.OrganizationId for us to assert against. public async Task GetAsync_HandlesNoPolicies(Guid userId)
/// </summary>
private class TestPolicyRequirement : IPolicyRequirement
{ {
public Guid OrganizationId { get; init; } var policyRepository = Substitute.For<IPolicyRepository>();
public static TestPolicyRequirement Create(IEnumerable<PolicyDetails> policyDetails) policyRepository.GetPolicyDetailsByUserId(userId).Returns([]);
=> new() { OrganizationId = policyDetails.Single().OrganizationId };
var factory = new TestPolicyRequirementFactory(x => x.IsProvider);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
var requirement = await sut.GetAsync<TestPolicyRequirement>(userId);
Assert.Empty(requirement.Policies);
} }
} }

Some files were not shown because too many files have changed in this diff Show More