1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-25 23:02:17 -05:00

Merge branch 'main' into PM-16921

This commit is contained in:
Conner Turnbull 2025-04-14 15:09:58 -04:00 committed by GitHub
commit f3b3b6cb79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
77 changed files with 4710 additions and 290 deletions

2
.github/CODEOWNERS vendored
View File

@ -37,6 +37,8 @@ util/Setup/** @bitwarden/dept-bre @bitwarden/team-platform-dev
**/Auth @bitwarden/team-auth-dev **/Auth @bitwarden/team-auth-dev
bitwarden_license/src/Sso @bitwarden/team-auth-dev bitwarden_license/src/Sso @bitwarden/team-auth-dev
src/Identity @bitwarden/team-auth-dev src/Identity @bitwarden/team-auth-dev
src/Core/Identity @bitwarden/team-auth-dev
src/Core/IdentityServer @bitwarden/team-auth-dev
# Key Management team # Key Management team
**/KeyManagement @bitwarden/team-key-management-dev **/KeyManagement @bitwarden/team-key-management-dev

View File

@ -627,55 +627,16 @@ jobs:
} }
}) })
trigger-ee-updates: setup-ephemeral-environment:
name: Trigger Ephemeral Environment updates name: Setup Ephemeral Environment
if: | needs: build-docker
needs.build-artifacts.outputs.has_secrets == 'true'
&& github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'ephemeral-environment')
runs-on: ubuntu-24.04
needs:
- build-docker
steps:
- name: Log in to Azure - CI subscription
uses: Azure/login@e15b166166a8746d1a47596803bd8c1b595455cf # v1.6.0
with:
creds: ${{ secrets.AZURE_KV_CI_SERVICE_PRINCIPAL }}
- name: Retrieve GitHub PAT secrets
id: retrieve-secret-pat
uses: bitwarden/gh-actions/get-keyvault-secrets@main
with:
keyvault: "bitwarden-ci"
secrets: "github-pat-bitwarden-devops-bot-repo-scope"
- name: Trigger Ephemeral Environment update
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}
script: |
await github.rest.actions.createWorkflowDispatch({
owner: 'bitwarden',
repo: 'devops',
workflow_id: '_update_ephemeral_tags.yml',
ref: 'main',
inputs: {
ephemeral_env_branch: process.env.GITHUB_HEAD_REF
}
})
trigger-ephemeral-environment-sync:
name: Trigger Ephemeral Environment Sync
needs: trigger-ee-updates
if: | if: |
needs.build-artifacts.outputs.has_secrets == 'true' needs.build-artifacts.outputs.has_secrets == 'true'
&& github.event_name == 'pull_request' && github.event_name == 'pull_request'
&& contains(github.event.pull_request.labels.*.name, 'ephemeral-environment') && contains(github.event.pull_request.labels.*.name, 'ephemeral-environment')
uses: bitwarden/gh-actions/.github/workflows/_ephemeral_environment_manager.yml@main uses: bitwarden/gh-actions/.github/workflows/_ephemeral_environment_manager.yml@main
with: with:
ephemeral_env_branch: process.env.GITHUB_HEAD_REF
project: server project: server
sync_environment: true
pull_request_number: ${{ github.event.number }} pull_request_number: ${{ github.event.number }}
secrets: inherit secrets: inherit

View File

@ -5,34 +5,12 @@ on:
types: [labeled] types: [labeled]
jobs: jobs:
trigger-ee-updates: setup-ephemeral-environment:
name: Trigger Ephemeral Environment updates name: Setup Ephemeral Environment
runs-on: ubuntu-24.04
if: github.event.label.name == 'ephemeral-environment' if: github.event.label.name == 'ephemeral-environment'
steps: uses: bitwarden/gh-actions/.github/workflows/_ephemeral_environment_manager.yml@main
- name: Log in to Azure - CI subscription
uses: Azure/login@e15b166166a8746d1a47596803bd8c1b595455cf # v1.6.0
with: with:
creds: ${{ secrets.AZURE_KV_CI_SERVICE_PRINCIPAL }} project: server
pull_request_number: ${{ github.event.number }}
- name: Retrieve GitHub PAT secrets sync_environment: true
id: retrieve-secret-pat secrets: inherit
uses: bitwarden/gh-actions/get-keyvault-secrets@main
with:
keyvault: "bitwarden-ci"
secrets: "github-pat-bitwarden-devops-bot-repo-scope"
- name: Trigger Ephemeral Environment update
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}
script: |
await github.rest.actions.createWorkflowDispatch({
owner: 'bitwarden',
repo: 'devops',
workflow_id: '_update_ephemeral_tags.yml',
ref: 'main',
inputs: {
ephemeral_env_branch: process.env.GITHUB_HEAD_REF
}
})

View File

@ -48,7 +48,7 @@ public class CreateProviderCommand : ICreateProviderCommand
await ProviderRepositoryCreateAsync(provider, ProviderStatusType.Created); await ProviderRepositoryCreateAsync(provider, ProviderStatusType.Created);
} }
public async Task CreateMultiOrganizationEnterpriseAsync(Provider provider, string ownerEmail, PlanType plan, int minimumSeats) public async Task CreateBusinessUnitAsync(Provider provider, string ownerEmail, PlanType plan, int minimumSeats)
{ {
var providerId = await CreateProviderAsync(provider, ownerEmail); var providerId = await CreateProviderAsync(provider, ownerEmail);

View File

@ -692,10 +692,10 @@ public class ProviderService : IProviderService
throw new BadRequestException($"Managed Service Providers cannot manage organizations with the plan type {requestedType}. Only Teams (Monthly) and Enterprise (Monthly) are allowed."); throw new BadRequestException($"Managed Service Providers cannot manage organizations with the plan type {requestedType}. Only Teams (Monthly) and Enterprise (Monthly) are allowed.");
} }
break; break;
case ProviderType.MultiOrganizationEnterprise: case ProviderType.BusinessUnit:
if (requestedType is not (PlanType.EnterpriseMonthly or PlanType.EnterpriseAnnually)) if (requestedType is not (PlanType.EnterpriseMonthly or PlanType.EnterpriseAnnually))
{ {
throw new BadRequestException($"Multi-organization Enterprise Providers cannot manage organizations with the plan type {requestedType}. Only Enterprise (Monthly) and Enterprise (Annually) are allowed."); throw new BadRequestException($"Business Unit Providers cannot manage organizations with the plan type {requestedType}. Only Enterprise (Monthly) and Enterprise (Annually) are allowed.");
} }
break; break;
case ProviderType.Reseller: case ProviderType.Reseller:

View File

@ -0,0 +1,462 @@
#nullable enable
using System.Diagnostics.CodeAnalysis;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Extensions;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.Extensions.Logging;
using OneOf;
using Stripe;
namespace Bit.Commercial.Core.Billing;
[RequireFeature(FeatureFlagKeys.PM18770_EnableOrganizationBusinessUnitConversion)]
public class BusinessUnitConverter(
IDataProtectionProvider dataProtectionProvider,
GlobalSettings globalSettings,
ILogger<BusinessUnitConverter> logger,
IMailService mailService,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IPricingClient pricingClient,
IProviderOrganizationRepository providerOrganizationRepository,
IProviderPlanRepository providerPlanRepository,
IProviderRepository providerRepository,
IProviderUserRepository providerUserRepository,
IStripeAdapter stripeAdapter,
ISubscriberService subscriberService,
IUserRepository userRepository) : IBusinessUnitConverter
{
private readonly IDataProtector _dataProtector =
dataProtectionProvider.CreateProtector($"{nameof(BusinessUnitConverter)}DataProtector");
public async Task<Guid> FinalizeConversion(
Organization organization,
Guid userId,
string token,
string providerKey,
string organizationKey)
{
var user = await userRepository.GetByIdAsync(userId);
var (subscription, provider, providerOrganization, providerUser) = await ValidateFinalizationAsync(organization, user, token);
var existingPlan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var updatedPlan = await pricingClient.GetPlanOrThrow(existingPlan.IsAnnual ? PlanType.EnterpriseAnnually : PlanType.EnterpriseMonthly);
// Bring organization under management.
organization.Plan = updatedPlan.Name;
organization.PlanType = updatedPlan.Type;
organization.MaxCollections = updatedPlan.PasswordManager.MaxCollections;
organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb;
organization.UsePolicies = updatedPlan.HasPolicies;
organization.UseSso = updatedPlan.HasSso;
organization.UseGroups = updatedPlan.HasGroups;
organization.UseEvents = updatedPlan.HasEvents;
organization.UseDirectory = updatedPlan.HasDirectory;
organization.UseTotp = updatedPlan.HasTotp;
organization.Use2fa = updatedPlan.Has2fa;
organization.UseApi = updatedPlan.HasApi;
organization.UseResetPassword = updatedPlan.HasResetPassword;
organization.SelfHost = updatedPlan.HasSelfHost;
organization.UsersGetPremium = updatedPlan.UsersGetPremium;
organization.UseCustomPermissions = updatedPlan.HasCustomPermissions;
organization.UseScim = updatedPlan.HasScim;
organization.UseKeyConnector = updatedPlan.HasKeyConnector;
organization.MaxStorageGb = updatedPlan.PasswordManager.BaseStorageGb;
organization.BillingEmail = provider.BillingEmail!;
organization.GatewayCustomerId = null;
organization.GatewaySubscriptionId = null;
organization.ExpirationDate = null;
organization.MaxAutoscaleSeats = null;
organization.Status = OrganizationStatusType.Managed;
// Enable organization access via key exchange.
providerOrganization.Key = organizationKey;
// Complete provider setup.
provider.Gateway = GatewayType.Stripe;
provider.GatewayCustomerId = subscription.CustomerId;
provider.GatewaySubscriptionId = subscription.Id;
provider.Status = ProviderStatusType.Billable;
// Enable provider access via key exchange.
providerUser.Key = providerKey;
providerUser.Status = ProviderUserStatusType.Confirmed;
// Stripe requires that we clear all the custom fields from the invoice settings if we want to replace them.
await stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions
{
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields = []
}
});
var metadata = new Dictionary<string, string>
{
[StripeConstants.MetadataKeys.OrganizationId] = string.Empty,
[StripeConstants.MetadataKeys.ProviderId] = provider.Id.ToString(),
["convertedFrom"] = organization.Id.ToString()
};
var updateCustomer = stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions
{
InvoiceSettings = new CustomerInvoiceSettingsOptions
{
CustomFields = [
new CustomerInvoiceSettingsCustomFieldOptions
{
Name = provider.SubscriberType(),
Value = provider.DisplayName()?.Length <= 30
? provider.DisplayName()
: provider.DisplayName()?[..30]
}
]
},
Metadata = metadata
});
// Find the existing password manager price on the subscription.
var passwordManagerItem = subscription.Items.First(item =>
{
var priceId = existingPlan.HasNonSeatBasedPasswordManagerPlan()
? existingPlan.PasswordManager.StripePlanId
: existingPlan.PasswordManager.StripeSeatPlanId;
return item.Price.Id == priceId;
});
// Get the new business unit price.
var updatedPriceId = ProviderPriceAdapter.GetActivePriceId(provider, updatedPlan.Type);
// Replace the existing password manager price with the new business unit price.
var updateSubscription =
stripeAdapter.SubscriptionUpdateAsync(subscription.Id,
new SubscriptionUpdateOptions
{
Items = [
new SubscriptionItemOptions
{
Id = passwordManagerItem.Id,
Deleted = true
},
new SubscriptionItemOptions
{
Price = updatedPriceId,
Quantity = organization.Seats
}
],
Metadata = metadata
});
await Task.WhenAll(updateCustomer, updateSubscription);
// Complete database updates for provider setup.
await Task.WhenAll(
organizationRepository.ReplaceAsync(organization),
providerOrganizationRepository.ReplaceAsync(providerOrganization),
providerRepository.ReplaceAsync(provider),
providerUserRepository.ReplaceAsync(providerUser));
return provider.Id;
}
public async Task<OneOf<Guid, List<string>>> InitiateConversion(
Organization organization,
string providerAdminEmail)
{
var user = await userRepository.GetByEmailAsync(providerAdminEmail);
var problems = await ValidateInitiationAsync(organization, user);
if (problems is { Count: > 0 })
{
return problems;
}
var provider = await providerRepository.CreateAsync(new Provider
{
Name = organization.Name,
BillingEmail = organization.BillingEmail,
Status = ProviderStatusType.Pending,
UseEvents = true,
Type = ProviderType.BusinessUnit
});
var plan = await pricingClient.GetPlanOrThrow(organization.PlanType);
var managedPlanType = plan.IsAnnual
? PlanType.EnterpriseAnnually
: PlanType.EnterpriseMonthly;
var createProviderOrganization = providerOrganizationRepository.CreateAsync(new ProviderOrganization
{
ProviderId = provider.Id,
OrganizationId = organization.Id
});
var createProviderPlan = providerPlanRepository.CreateAsync(new ProviderPlan
{
ProviderId = provider.Id,
PlanType = managedPlanType,
SeatMinimum = 0,
PurchasedSeats = organization.Seats,
AllocatedSeats = organization.Seats
});
var createProviderUser = providerUserRepository.CreateAsync(new ProviderUser
{
ProviderId = provider.Id,
UserId = user!.Id,
Email = user.Email,
Status = ProviderUserStatusType.Invited,
Type = ProviderUserType.ProviderAdmin
});
await Task.WhenAll(createProviderOrganization, createProviderPlan, createProviderUser);
await SendInviteAsync(organization, user.Email);
return provider.Id;
}
public Task ResendConversionInvite(
Organization organization,
string providerAdminEmail) =>
IfConversionInProgressAsync(organization, providerAdminEmail,
async (_, _, providerUser) =>
{
if (!string.IsNullOrEmpty(providerUser.Email))
{
await SendInviteAsync(organization, providerUser.Email);
}
});
public Task ResetConversion(
Organization organization,
string providerAdminEmail) =>
IfConversionInProgressAsync(organization, providerAdminEmail,
async (provider, providerOrganization, providerUser) =>
{
var tasks = new List<Task>
{
providerOrganizationRepository.DeleteAsync(providerOrganization),
providerUserRepository.DeleteAsync(providerUser)
};
var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id);
if (providerPlans is { Count: > 0 })
{
tasks.AddRange(providerPlans.Select(providerPlanRepository.DeleteAsync));
}
await Task.WhenAll(tasks);
await providerRepository.DeleteAsync(provider);
});
#region Utilities
private async Task IfConversionInProgressAsync(
Organization organization,
string providerAdminEmail,
Func<Provider, ProviderOrganization, ProviderUser, Task> callback)
{
var user = await userRepository.GetByEmailAsync(providerAdminEmail);
if (user == null)
{
return;
}
var provider = await providerRepository.GetByOrganizationIdAsync(organization.Id);
if (provider is not
{
Type: ProviderType.BusinessUnit,
Status: ProviderStatusType.Pending
})
{
return;
}
var providerUser = await providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id);
if (providerUser is
{
Type: ProviderUserType.ProviderAdmin,
Status: ProviderUserStatusType.Invited
})
{
var providerOrganization = await providerOrganizationRepository.GetByOrganizationId(organization.Id);
await callback(provider, providerOrganization!, providerUser);
}
}
private async Task SendInviteAsync(
Organization organization,
string providerAdminEmail)
{
var token = _dataProtector.Protect(
$"BusinessUnitConversionInvite {organization.Id} {providerAdminEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}");
await mailService.SendBusinessUnitConversionInviteAsync(organization, token, providerAdminEmail);
}
private async Task<(Subscription, Provider, ProviderOrganization, ProviderUser)> ValidateFinalizationAsync(
Organization organization,
User? user,
string token)
{
if (organization.PlanType.GetProductTier() != ProductTierType.Enterprise)
{
Fail("Organization must be on an enterprise plan.");
}
var subscription = await subscriberService.GetSubscription(organization);
if (subscription is not
{
Status:
StripeConstants.SubscriptionStatus.Active or
StripeConstants.SubscriptionStatus.Trialing or
StripeConstants.SubscriptionStatus.PastDue
})
{
Fail("Organization must have a valid subscription.");
}
if (user == null)
{
Fail("Provider admin must be a Bitwarden user.");
}
if (!CoreHelpers.TokenIsValid(
"BusinessUnitConversionInvite",
_dataProtector,
token,
user.Email,
organization.Id,
globalSettings.OrganizationInviteExpirationHours))
{
Fail("Email token is invalid.");
}
var organizationUser =
await organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id);
if (organizationUser is not
{
Status: OrganizationUserStatusType.Confirmed
})
{
Fail("Provider admin must be a confirmed member of the organization being converted.");
}
var provider = await providerRepository.GetByOrganizationIdAsync(organization.Id);
if (provider is not
{
Type: ProviderType.BusinessUnit,
Status: ProviderStatusType.Pending
})
{
Fail("Linked provider is not a pending business unit.");
}
var providerUser = await providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id);
if (providerUser is not
{
Type: ProviderUserType.ProviderAdmin,
Status: ProviderUserStatusType.Invited
})
{
Fail("Provider admin has not been invited.");
}
var providerOrganization = await providerOrganizationRepository.GetByOrganizationId(organization.Id);
return (subscription, provider, providerOrganization!, providerUser);
[DoesNotReturn]
void Fail(string scopedError)
{
logger.LogError("Could not finalize business unit conversion for organization ({OrganizationID}): {Error}",
organization.Id, scopedError);
throw new BillingException();
}
}
private async Task<List<string>?> ValidateInitiationAsync(
Organization organization,
User? user)
{
var problems = new List<string>();
if (organization.PlanType.GetProductTier() != ProductTierType.Enterprise)
{
problems.Add("Organization must be on an enterprise plan.");
}
var subscription = await subscriberService.GetSubscription(organization);
if (subscription is not
{
Status:
StripeConstants.SubscriptionStatus.Active or
StripeConstants.SubscriptionStatus.Trialing or
StripeConstants.SubscriptionStatus.PastDue
})
{
problems.Add("Organization must have a valid subscription.");
}
var providerOrganization = await providerOrganizationRepository.GetByOrganizationId(organization.Id);
if (providerOrganization != null)
{
problems.Add("Organization is already linked to a provider.");
}
if (user == null)
{
problems.Add("Provider admin must be a Bitwarden user.");
}
else
{
var organizationUser =
await organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id);
if (organizationUser is not
{
Status: OrganizationUserStatusType.Confirmed
})
{
problems.Add("Provider admin must be a confirmed member of the organization being converted.");
}
}
return problems.Count == 0 ? null : problems;
}
#endregion
}

View File

@ -791,7 +791,7 @@ public class ProviderBillingService(
Provider provider, Provider provider,
Organization organization) Organization organization)
{ {
if (provider.Type == ProviderType.MultiOrganizationEnterprise) if (provider.Type == ProviderType.BusinessUnit)
{ {
return (await providerPlanRepository.GetByProviderId(provider.Id)).First().PlanType; return (await providerPlanRepository.GetByProviderId(provider.Id)).First().PlanType;
} }

View File

@ -51,7 +51,7 @@ public static class ProviderPriceAdapter
/// <param name="subscription">The provider's subscription.</param> /// <param name="subscription">The provider's subscription.</param>
/// <param name="planType">The plan type correlating to the desired Stripe price ID.</param> /// <param name="planType">The plan type correlating to the desired Stripe price ID.</param>
/// <returns>A Stripe <see cref="Stripe.Price"/> ID.</returns> /// <returns>A Stripe <see cref="Stripe.Price"/> ID.</returns>
/// <exception cref="BillingException">Thrown when the provider's type is not <see cref="ProviderType.Msp"/> or <see cref="ProviderType.MultiOrganizationEnterprise"/>.</exception> /// <exception cref="BillingException">Thrown when the provider's type is not <see cref="ProviderType.Msp"/> or <see cref="ProviderType.BusinessUnit"/>.</exception>
/// <exception cref="BillingException">Thrown when the provided <see cref="planType"/> does not relate to a Stripe price ID.</exception> /// <exception cref="BillingException">Thrown when the provided <see cref="planType"/> does not relate to a Stripe price ID.</exception>
public static string GetPriceId( public static string GetPriceId(
Provider provider, Provider provider,
@ -78,7 +78,7 @@ public static class ProviderPriceAdapter
PlanType.EnterpriseMonthly => MSP.Active.Enterprise, PlanType.EnterpriseMonthly => MSP.Active.Enterprise,
_ => throw invalidPlanType _ => throw invalidPlanType
}, },
ProviderType.MultiOrganizationEnterprise => BusinessUnit.Legacy.List.Intersect(priceIds).Any() ProviderType.BusinessUnit => BusinessUnit.Legacy.List.Intersect(priceIds).Any()
? planType switch ? planType switch
{ {
PlanType.EnterpriseAnnually => BusinessUnit.Legacy.Annually, PlanType.EnterpriseAnnually => BusinessUnit.Legacy.Annually,
@ -103,7 +103,7 @@ public static class ProviderPriceAdapter
/// <param name="provider">The provider to get the Stripe price ID for.</param> /// <param name="provider">The provider to get the Stripe price ID for.</param>
/// <param name="planType">The plan type correlating to the desired Stripe price ID.</param> /// <param name="planType">The plan type correlating to the desired Stripe price ID.</param>
/// <returns>A Stripe <see cref="Stripe.Price"/> ID.</returns> /// <returns>A Stripe <see cref="Stripe.Price"/> ID.</returns>
/// <exception cref="BillingException">Thrown when the provider's type is not <see cref="ProviderType.Msp"/> or <see cref="ProviderType.MultiOrganizationEnterprise"/>.</exception> /// <exception cref="BillingException">Thrown when the provider's type is not <see cref="ProviderType.Msp"/> or <see cref="ProviderType.BusinessUnit"/>.</exception>
/// <exception cref="BillingException">Thrown when the provided <see cref="planType"/> does not relate to a Stripe price ID.</exception> /// <exception cref="BillingException">Thrown when the provided <see cref="planType"/> does not relate to a Stripe price ID.</exception>
public static string GetActivePriceId( public static string GetActivePriceId(
Provider provider, Provider provider,
@ -120,7 +120,7 @@ public static class ProviderPriceAdapter
PlanType.EnterpriseMonthly => MSP.Active.Enterprise, PlanType.EnterpriseMonthly => MSP.Active.Enterprise,
_ => throw invalidPlanType _ => throw invalidPlanType
}, },
ProviderType.MultiOrganizationEnterprise => planType switch ProviderType.BusinessUnit => planType switch
{ {
PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually, PlanType.EnterpriseAnnually => BusinessUnit.Active.Annually,
PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly, PlanType.EnterpriseMonthly => BusinessUnit.Active.Monthly,

View File

@ -16,5 +16,6 @@ public static class ServiceCollectionExtensions
services.AddScoped<ICreateProviderCommand, CreateProviderCommand>(); services.AddScoped<ICreateProviderCommand, CreateProviderCommand>();
services.AddScoped<IRemoveOrganizationFromProviderCommand, RemoveOrganizationFromProviderCommand>(); services.AddScoped<IRemoveOrganizationFromProviderCommand, RemoveOrganizationFromProviderCommand>();
services.AddTransient<IProviderBillingService, ProviderBillingService>(); services.AddTransient<IProviderBillingService, ProviderBillingService>();
services.AddTransient<IBusinessUnitConverter, BusinessUnitConverter>();
} }
} }

View File

@ -63,7 +63,7 @@ public class CreateProviderCommandTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task CreateMultiOrganizationEnterpriseAsync_Success( public async Task CreateBusinessUnitAsync_Success(
Provider provider, Provider provider,
User user, User user,
PlanType plan, PlanType plan,
@ -71,13 +71,13 @@ public class CreateProviderCommandTests
SutProvider<CreateProviderCommand> sutProvider) SutProvider<CreateProviderCommand> sutProvider)
{ {
// Arrange // Arrange
provider.Type = ProviderType.MultiOrganizationEnterprise; provider.Type = ProviderType.BusinessUnit;
var userRepository = sutProvider.GetDependency<IUserRepository>(); var userRepository = sutProvider.GetDependency<IUserRepository>();
userRepository.GetByEmailAsync(user.Email).Returns(user); userRepository.GetByEmailAsync(user.Email).Returns(user);
// Act // Act
await sutProvider.Sut.CreateMultiOrganizationEnterpriseAsync(provider, user.Email, plan, minimumSeats); await sutProvider.Sut.CreateBusinessUnitAsync(provider, user.Email, plan, minimumSeats);
// Assert // Assert
await sutProvider.GetDependency<IProviderRepository>().ReceivedWithAnyArgs().CreateAsync(provider); await sutProvider.GetDependency<IProviderRepository>().ReceivedWithAnyArgs().CreateAsync(provider);
@ -85,7 +85,7 @@ public class CreateProviderCommandTests
} }
[Theory, BitAutoData] [Theory, BitAutoData]
public async Task CreateMultiOrganizationEnterpriseAsync_UserIdIsInvalid_Throws( public async Task CreateBusinessUnitAsync_UserIdIsInvalid_Throws(
Provider provider, Provider provider,
SutProvider<CreateProviderCommand> sutProvider) SutProvider<CreateProviderCommand> sutProvider)
{ {
@ -94,7 +94,7 @@ public class CreateProviderCommandTests
// Act // Act
var exception = await Assert.ThrowsAsync<BadRequestException>( var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.CreateMultiOrganizationEnterpriseAsync(provider, default, default, default)); () => sutProvider.Sut.CreateBusinessUnitAsync(provider, default, default, default));
// Assert // Assert
Assert.Contains("Invalid owner.", exception.Message); Assert.Contains("Invalid owner.", exception.Message);

View File

@ -0,0 +1,501 @@
#nullable enable
using System.Text;
using Bit.Commercial.Core.Billing;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing;
using Bit.Core.Billing.Constants;
using Bit.Core.Billing.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Billing.Pricing;
using Bit.Core.Billing.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
namespace Bit.Commercial.Core.Test.Billing;
public class BusinessUnitConverterTests
{
private readonly IDataProtectionProvider _dataProtectionProvider = Substitute.For<IDataProtectionProvider>();
private readonly GlobalSettings _globalSettings = new();
private readonly ILogger<BusinessUnitConverter> _logger = Substitute.For<ILogger<BusinessUnitConverter>>();
private readonly IMailService _mailService = Substitute.For<IMailService>();
private readonly IOrganizationRepository _organizationRepository = Substitute.For<IOrganizationRepository>();
private readonly IOrganizationUserRepository _organizationUserRepository = Substitute.For<IOrganizationUserRepository>();
private readonly IPricingClient _pricingClient = Substitute.For<IPricingClient>();
private readonly IProviderOrganizationRepository _providerOrganizationRepository = Substitute.For<IProviderOrganizationRepository>();
private readonly IProviderPlanRepository _providerPlanRepository = Substitute.For<IProviderPlanRepository>();
private readonly IProviderRepository _providerRepository = Substitute.For<IProviderRepository>();
private readonly IProviderUserRepository _providerUserRepository = Substitute.For<IProviderUserRepository>();
private readonly IStripeAdapter _stripeAdapter = Substitute.For<IStripeAdapter>();
private readonly ISubscriberService _subscriberService = Substitute.For<ISubscriberService>();
private readonly IUserRepository _userRepository = Substitute.For<IUserRepository>();
private BusinessUnitConverter BuildConverter() => new(
_dataProtectionProvider,
_globalSettings,
_logger,
_mailService,
_organizationRepository,
_organizationUserRepository,
_pricingClient,
_providerOrganizationRepository,
_providerPlanRepository,
_providerRepository,
_providerUserRepository,
_stripeAdapter,
_subscriberService,
_userRepository);
#region FinalizeConversion
[Theory, BitAutoData]
public async Task FinalizeConversion_Succeeds_ReturnsProviderId(
Organization organization,
Guid userId,
string providerKey,
string organizationKey)
{
organization.PlanType = PlanType.EnterpriseAnnually2020;
var enterpriseAnnually2020 = StaticStore.GetPlan(PlanType.EnterpriseAnnually2020);
var subscription = new Subscription
{
Id = "subscription_id",
CustomerId = "customer_id",
Status = StripeConstants.SubscriptionStatus.Active,
Items = new StripeList<SubscriptionItem>
{
Data = [
new SubscriptionItem
{
Id = "subscription_item_id",
Price = new Price
{
Id = enterpriseAnnually2020.PasswordManager.StripeSeatPlanId
}
}
]
}
};
_subscriberService.GetSubscription(organization).Returns(subscription);
var user = new User
{
Id = Guid.NewGuid(),
Email = "provider-admin@example.com"
};
_userRepository.GetByIdAsync(userId).Returns(user);
var token = SetupDataProtection(organization, user.Email);
var organizationUser = new OrganizationUser { Status = OrganizationUserStatusType.Confirmed };
_organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id)
.Returns(organizationUser);
var provider = new Provider
{
Type = ProviderType.BusinessUnit,
Status = ProviderStatusType.Pending
};
_providerRepository.GetByOrganizationIdAsync(organization.Id).Returns(provider);
var providerUser = new ProviderUser
{
Type = ProviderUserType.ProviderAdmin,
Status = ProviderUserStatusType.Invited
};
_providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser);
var providerOrganization = new ProviderOrganization();
_providerOrganizationRepository.GetByOrganizationId(organization.Id).Returns(providerOrganization);
_pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually2020)
.Returns(enterpriseAnnually2020);
var enterpriseAnnually = StaticStore.GetPlan(PlanType.EnterpriseAnnually);
_pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually)
.Returns(enterpriseAnnually);
var businessUnitConverter = BuildConverter();
await businessUnitConverter.FinalizeConversion(organization, userId, token, providerKey, organizationKey);
await _stripeAdapter.Received(2).CustomerUpdateAsync(subscription.CustomerId, Arg.Any<CustomerUpdateOptions>());
var updatedPriceId = ProviderPriceAdapter.GetActivePriceId(provider, enterpriseAnnually.Type);
await _stripeAdapter.Received(1).SubscriptionUpdateAsync(subscription.Id, Arg.Is<SubscriptionUpdateOptions>(
arguments =>
arguments.Items.Count == 2 &&
arguments.Items[0].Id == "subscription_item_id" &&
arguments.Items[0].Deleted == true &&
arguments.Items[1].Price == updatedPriceId &&
arguments.Items[1].Quantity == organization.Seats));
await _organizationRepository.Received(1).ReplaceAsync(Arg.Is<Organization>(arguments =>
arguments.PlanType == PlanType.EnterpriseAnnually &&
arguments.Status == OrganizationStatusType.Managed &&
arguments.GatewayCustomerId == null &&
arguments.GatewaySubscriptionId == null));
await _providerOrganizationRepository.Received(1).ReplaceAsync(Arg.Is<ProviderOrganization>(arguments =>
arguments.Key == organizationKey));
await _providerRepository.Received(1).ReplaceAsync(Arg.Is<Provider>(arguments =>
arguments.Gateway == GatewayType.Stripe &&
arguments.GatewayCustomerId == subscription.CustomerId &&
arguments.GatewaySubscriptionId == subscription.Id &&
arguments.Status == ProviderStatusType.Billable));
await _providerUserRepository.Received(1).ReplaceAsync(Arg.Is<ProviderUser>(arguments =>
arguments.Key == providerKey &&
arguments.Status == ProviderUserStatusType.Confirmed));
}
/*
* Because the validation for finalization is not an applicative like initialization is,
* I'm just testing one specific failure here. I don't see much value in testing every single opportunity for failure.
*/
[Theory, BitAutoData]
public async Task FinalizeConversion_ValidationFails_ThrowsBillingException(
Organization organization,
Guid userId,
string token,
string providerKey,
string organizationKey)
{
organization.PlanType = PlanType.EnterpriseAnnually2020;
var subscription = new Subscription
{
Status = StripeConstants.SubscriptionStatus.Canceled
};
_subscriberService.GetSubscription(organization).Returns(subscription);
var businessUnitConverter = BuildConverter();
await Assert.ThrowsAsync<BillingException>(() =>
businessUnitConverter.FinalizeConversion(organization, userId, token, providerKey, organizationKey));
await _organizationUserRepository.DidNotReceiveWithAnyArgs()
.GetByOrganizationAsync(Arg.Any<Guid>(), Arg.Any<Guid>());
}
#endregion
#region InitiateConversion
[Theory, BitAutoData]
public async Task InitiateConversion_Succeeds_ReturnsProviderId(
Organization organization,
string providerAdminEmail)
{
organization.PlanType = PlanType.EnterpriseAnnually;
_subscriberService.GetSubscription(organization).Returns(new Subscription
{
Status = StripeConstants.SubscriptionStatus.Active
});
var user = new User
{
Id = Guid.NewGuid(),
Email = providerAdminEmail
};
_userRepository.GetByEmailAsync(providerAdminEmail).Returns(user);
var organizationUser = new OrganizationUser { Status = OrganizationUserStatusType.Confirmed };
_organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id)
.Returns(organizationUser);
var provider = new Provider { Id = Guid.NewGuid() };
_providerRepository.CreateAsync(Arg.Is<Provider>(argument =>
argument.Name == organization.Name &&
argument.BillingEmail == organization.BillingEmail &&
argument.Status == ProviderStatusType.Pending &&
argument.Type == ProviderType.BusinessUnit)).Returns(provider);
var plan = StaticStore.GetPlan(organization.PlanType);
_pricingClient.GetPlanOrThrow(organization.PlanType).Returns(plan);
var token = SetupDataProtection(organization, providerAdminEmail);
var businessUnitConverter = BuildConverter();
var result = await businessUnitConverter.InitiateConversion(organization, providerAdminEmail);
Assert.True(result.IsT0);
var providerId = result.AsT0;
Assert.Equal(provider.Id, providerId);
await _providerOrganizationRepository.Received(1).CreateAsync(
Arg.Is<ProviderOrganization>(argument =>
argument.ProviderId == provider.Id &&
argument.OrganizationId == organization.Id));
await _providerPlanRepository.Received(1).CreateAsync(
Arg.Is<ProviderPlan>(argument =>
argument.ProviderId == provider.Id &&
argument.PlanType == PlanType.EnterpriseAnnually &&
argument.SeatMinimum == 0 &&
argument.PurchasedSeats == organization.Seats &&
argument.AllocatedSeats == organization.Seats));
await _providerUserRepository.Received(1).CreateAsync(
Arg.Is<ProviderUser>(argument =>
argument.ProviderId == provider.Id &&
argument.UserId == user.Id &&
argument.Email == user.Email &&
argument.Status == ProviderUserStatusType.Invited &&
argument.Type == ProviderUserType.ProviderAdmin));
await _mailService.Received(1).SendBusinessUnitConversionInviteAsync(
organization,
token,
user.Email);
}
[Theory, BitAutoData]
public async Task InitiateConversion_ValidationFails_ReturnsErrors(
Organization organization,
string providerAdminEmail)
{
organization.PlanType = PlanType.TeamsMonthly;
_subscriberService.GetSubscription(organization).Returns(new Subscription
{
Status = StripeConstants.SubscriptionStatus.Canceled
});
var user = new User
{
Id = Guid.NewGuid(),
Email = providerAdminEmail
};
_providerOrganizationRepository.GetByOrganizationId(organization.Id)
.Returns(new ProviderOrganization());
_userRepository.GetByEmailAsync(providerAdminEmail).Returns(user);
var organizationUser = new OrganizationUser { Status = OrganizationUserStatusType.Invited };
_organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id)
.Returns(organizationUser);
var businessUnitConverter = BuildConverter();
var result = await businessUnitConverter.InitiateConversion(organization, providerAdminEmail);
Assert.True(result.IsT1);
var problems = result.AsT1;
Assert.Contains("Organization must be on an enterprise plan.", problems);
Assert.Contains("Organization must have a valid subscription.", problems);
Assert.Contains("Organization is already linked to a provider.", problems);
Assert.Contains("Provider admin must be a confirmed member of the organization being converted.", problems);
}
#endregion
#region ResendConversionInvite
[Theory, BitAutoData]
public async Task ResendConversionInvite_ConversionInProgress_Succeeds(
Organization organization,
string providerAdminEmail)
{
SetupConversionInProgress(organization, providerAdminEmail);
var token = SetupDataProtection(organization, providerAdminEmail);
var businessUnitConverter = BuildConverter();
await businessUnitConverter.ResendConversionInvite(organization, providerAdminEmail);
await _mailService.Received(1).SendBusinessUnitConversionInviteAsync(
organization,
token,
providerAdminEmail);
}
[Theory, BitAutoData]
public async Task ResendConversionInvite_NoConversionInProgress_DoesNothing(
Organization organization,
string providerAdminEmail)
{
SetupDataProtection(organization, providerAdminEmail);
var businessUnitConverter = BuildConverter();
await businessUnitConverter.ResendConversionInvite(organization, providerAdminEmail);
await _mailService.DidNotReceiveWithAnyArgs().SendBusinessUnitConversionInviteAsync(
Arg.Any<Organization>(),
Arg.Any<string>(),
Arg.Any<string>());
}
#endregion
#region ResetConversion
[Theory, BitAutoData]
public async Task ResetConversion_ConversionInProgress_Succeeds(
Organization organization,
string providerAdminEmail)
{
var (provider, providerOrganization, providerUser, providerPlan) = SetupConversionInProgress(organization, providerAdminEmail);
var businessUnitConverter = BuildConverter();
await businessUnitConverter.ResetConversion(organization, providerAdminEmail);
await _providerOrganizationRepository.Received(1)
.DeleteAsync(providerOrganization);
await _providerUserRepository.Received(1)
.DeleteAsync(providerUser);
await _providerPlanRepository.Received(1)
.DeleteAsync(providerPlan);
await _providerRepository.Received(1)
.DeleteAsync(provider);
}
[Theory, BitAutoData]
public async Task ResetConversion_NoConversionInProgress_DoesNothing(
Organization organization,
string providerAdminEmail)
{
var businessUnitConverter = BuildConverter();
await businessUnitConverter.ResetConversion(organization, providerAdminEmail);
await _providerOrganizationRepository.DidNotReceiveWithAnyArgs()
.DeleteAsync(Arg.Any<ProviderOrganization>());
await _providerUserRepository.DidNotReceiveWithAnyArgs()
.DeleteAsync(Arg.Any<ProviderUser>());
await _providerPlanRepository.DidNotReceiveWithAnyArgs()
.DeleteAsync(Arg.Any<ProviderPlan>());
await _providerRepository.DidNotReceiveWithAnyArgs()
.DeleteAsync(Arg.Any<Provider>());
}
#endregion
#region Utilities
private string SetupDataProtection(
Organization organization,
string providerAdminEmail)
{
var dataProtector = new MockDataProtector(organization, providerAdminEmail);
_dataProtectionProvider.CreateProtector($"{nameof(BusinessUnitConverter)}DataProtector").Returns(dataProtector);
return dataProtector.Protect(dataProtector.Token);
}
private (Provider, ProviderOrganization, ProviderUser, ProviderPlan) SetupConversionInProgress(
Organization organization,
string providerAdminEmail)
{
var user = new User { Id = Guid.NewGuid() };
_userRepository.GetByEmailAsync(providerAdminEmail).Returns(user);
var provider = new Provider
{
Id = Guid.NewGuid(),
Type = ProviderType.BusinessUnit,
Status = ProviderStatusType.Pending
};
_providerRepository.GetByOrganizationIdAsync(organization.Id).Returns(provider);
var providerUser = new ProviderUser
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
UserId = user.Id,
Type = ProviderUserType.ProviderAdmin,
Status = ProviderUserStatusType.Invited,
Email = providerAdminEmail
};
_providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id)
.Returns(providerUser);
var providerOrganization = new ProviderOrganization
{
Id = Guid.NewGuid(),
OrganizationId = organization.Id,
ProviderId = provider.Id
};
_providerOrganizationRepository.GetByOrganizationId(organization.Id)
.Returns(providerOrganization);
var providerPlan = new ProviderPlan
{
Id = Guid.NewGuid(),
ProviderId = provider.Id,
PlanType = PlanType.EnterpriseAnnually
};
_providerPlanRepository.GetByProviderId(provider.Id).Returns([providerPlan]);
return (provider, providerOrganization, providerUser, providerPlan);
}
#endregion
}
public class MockDataProtector(
Organization organization,
string providerAdminEmail) : IDataProtector
{
public string Token = $"BusinessUnitConversionInvite {organization.Id} {providerAdminEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}";
public IDataProtector CreateProtector(string purpose) => this;
public byte[] Protect(byte[] plaintext) => Encoding.UTF8.GetBytes(Token);
public byte[] Unprotect(byte[] protectedData) => Encoding.UTF8.GetBytes(Token);
}

View File

@ -116,7 +116,7 @@ public class ProviderBillingServiceTests
SutProvider<ProviderBillingService> sutProvider) SutProvider<ProviderBillingService> sutProvider)
{ {
// Arrange // Arrange
provider.Type = ProviderType.MultiOrganizationEnterprise; provider.Type = ProviderType.BusinessUnit;
var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>(); var providerPlanRepository = sutProvider.GetDependency<IProviderPlanRepository>();
var existingPlan = new ProviderPlan var existingPlan = new ProviderPlan

View File

@ -71,7 +71,7 @@ public class ProviderPriceAdapterTests
var provider = new Provider var provider = new Provider
{ {
Id = Guid.NewGuid(), Id = Guid.NewGuid(),
Type = ProviderType.MultiOrganizationEnterprise Type = ProviderType.BusinessUnit
}; };
var subscription = new Subscription var subscription = new Subscription
@ -98,7 +98,7 @@ public class ProviderPriceAdapterTests
var provider = new Provider var provider = new Provider
{ {
Id = Guid.NewGuid(), Id = Guid.NewGuid(),
Type = ProviderType.MultiOrganizationEnterprise Type = ProviderType.BusinessUnit
}; };
var subscription = new Subscription var subscription = new Subscription
@ -141,7 +141,7 @@ public class ProviderPriceAdapterTests
var provider = new Provider var provider = new Provider
{ {
Id = Guid.NewGuid(), Id = Guid.NewGuid(),
Type = ProviderType.MultiOrganizationEnterprise Type = ProviderType.BusinessUnit
}; };
var result = ProviderPriceAdapter.GetActivePriceId(provider, planType); var result = ProviderPriceAdapter.GetActivePriceId(provider, planType);

View File

@ -14,9 +14,6 @@
<ProjectReference Include="..\Core\Core.csproj" /> <ProjectReference Include="..\Core\Core.csproj" />
<ProjectReference Include="..\..\util\SqliteMigrations\SqliteMigrations.csproj" /> <ProjectReference Include="..\..\util\SqliteMigrations\SqliteMigrations.csproj" />
</ItemGroup> </ItemGroup>
<ItemGroup>
<Folder Include="Billing\Controllers\" />
</ItemGroup>
<Choose> <Choose>
<When Condition="!$(DefineConstants.Contains('OSS'))"> <When Condition="!$(DefineConstants.Contains('OSS'))">

View File

@ -133,10 +133,10 @@ public class ProvidersController : Controller
return View(new CreateResellerProviderModel()); return View(new CreateResellerProviderModel());
} }
[HttpGet("providers/create/multi-organization-enterprise")] [HttpGet("providers/create/business-unit")]
public IActionResult CreateMultiOrganizationEnterprise(int enterpriseMinimumSeats, string ownerEmail = null) public IActionResult CreateBusinessUnit(int enterpriseMinimumSeats, string ownerEmail = null)
{ {
return View(new CreateMultiOrganizationEnterpriseProviderModel return View(new CreateBusinessUnitProviderModel
{ {
OwnerEmail = ownerEmail, OwnerEmail = ownerEmail,
EnterpriseSeatMinimum = enterpriseMinimumSeats EnterpriseSeatMinimum = enterpriseMinimumSeats
@ -157,7 +157,7 @@ public class ProvidersController : Controller
{ {
ProviderType.Msp => RedirectToAction("CreateMsp"), ProviderType.Msp => RedirectToAction("CreateMsp"),
ProviderType.Reseller => RedirectToAction("CreateReseller"), ProviderType.Reseller => RedirectToAction("CreateReseller"),
ProviderType.MultiOrganizationEnterprise => RedirectToAction("CreateMultiOrganizationEnterprise"), ProviderType.BusinessUnit => RedirectToAction("CreateBusinessUnit"),
_ => View(model) _ => View(model)
}; };
} }
@ -198,10 +198,10 @@ public class ProvidersController : Controller
return RedirectToAction("Edit", new { id = provider.Id }); return RedirectToAction("Edit", new { id = provider.Id });
} }
[HttpPost("providers/create/multi-organization-enterprise")] [HttpPost("providers/create/business-unit")]
[ValidateAntiForgeryToken] [ValidateAntiForgeryToken]
[RequirePermission(Permission.Provider_Create)] [RequirePermission(Permission.Provider_Create)]
public async Task<IActionResult> CreateMultiOrganizationEnterprise(CreateMultiOrganizationEnterpriseProviderModel model) public async Task<IActionResult> CreateBusinessUnit(CreateBusinessUnitProviderModel model)
{ {
if (!ModelState.IsValid) if (!ModelState.IsValid)
{ {
@ -209,7 +209,7 @@ public class ProvidersController : Controller
} }
var provider = model.ToProvider(); var provider = model.ToProvider();
await _createProviderCommand.CreateMultiOrganizationEnterpriseAsync( await _createProviderCommand.CreateBusinessUnitAsync(
provider, provider,
model.OwnerEmail, model.OwnerEmail,
model.Plan.Value, model.Plan.Value,
@ -307,7 +307,7 @@ public class ProvidersController : Controller
]); ]);
await _providerBillingService.UpdateSeatMinimums(updateMspSeatMinimumsCommand); await _providerBillingService.UpdateSeatMinimums(updateMspSeatMinimumsCommand);
break; break;
case ProviderType.MultiOrganizationEnterprise: case ProviderType.BusinessUnit:
{ {
var existingMoePlan = providerPlans.Single(); var existingMoePlan = providerPlans.Single();

View File

@ -6,7 +6,7 @@ using Bit.SharedWeb.Utilities;
namespace Bit.Admin.AdminConsole.Models; namespace Bit.Admin.AdminConsole.Models;
public class CreateMultiOrganizationEnterpriseProviderModel : IValidatableObject public class CreateBusinessUnitProviderModel : IValidatableObject
{ {
[Display(Name = "Owner Email")] [Display(Name = "Owner Email")]
public string OwnerEmail { get; set; } public string OwnerEmail { get; set; }
@ -22,7 +22,7 @@ public class CreateMultiOrganizationEnterpriseProviderModel : IValidatableObject
{ {
return new Provider return new Provider
{ {
Type = ProviderType.MultiOrganizationEnterprise Type = ProviderType.BusinessUnit
}; };
} }
@ -30,17 +30,17 @@ public class CreateMultiOrganizationEnterpriseProviderModel : IValidatableObject
{ {
if (string.IsNullOrWhiteSpace(OwnerEmail)) if (string.IsNullOrWhiteSpace(OwnerEmail))
{ {
var ownerEmailDisplayName = nameof(OwnerEmail).GetDisplayAttribute<CreateMultiOrganizationEnterpriseProviderModel>()?.GetName() ?? nameof(OwnerEmail); var ownerEmailDisplayName = nameof(OwnerEmail).GetDisplayAttribute<CreateBusinessUnitProviderModel>()?.GetName() ?? nameof(OwnerEmail);
yield return new ValidationResult($"The {ownerEmailDisplayName} field is required."); yield return new ValidationResult($"The {ownerEmailDisplayName} field is required.");
} }
if (EnterpriseSeatMinimum < 0) if (EnterpriseSeatMinimum < 0)
{ {
var enterpriseSeatMinimumDisplayName = nameof(EnterpriseSeatMinimum).GetDisplayAttribute<CreateMultiOrganizationEnterpriseProviderModel>()?.GetName() ?? nameof(EnterpriseSeatMinimum); var enterpriseSeatMinimumDisplayName = nameof(EnterpriseSeatMinimum).GetDisplayAttribute<CreateBusinessUnitProviderModel>()?.GetName() ?? nameof(EnterpriseSeatMinimum);
yield return new ValidationResult($"The {enterpriseSeatMinimumDisplayName} field can not be negative."); yield return new ValidationResult($"The {enterpriseSeatMinimumDisplayName} field can not be negative.");
} }
if (Plan != PlanType.EnterpriseAnnually && Plan != PlanType.EnterpriseMonthly) if (Plan != PlanType.EnterpriseAnnually && Plan != PlanType.EnterpriseMonthly)
{ {
var planDisplayName = nameof(Plan).GetDisplayAttribute<CreateMultiOrganizationEnterpriseProviderModel>()?.GetName() ?? nameof(Plan); var planDisplayName = nameof(Plan).GetDisplayAttribute<CreateBusinessUnitProviderModel>()?.GetName() ?? nameof(Plan);
yield return new ValidationResult($"The {planDisplayName} field must be set to Enterprise Annually or Enterprise Monthly."); yield return new ValidationResult($"The {planDisplayName} field must be set to Enterprise Annually or Enterprise Monthly.");
} }
} }

View File

@ -34,7 +34,7 @@ public class ProviderEditModel : ProviderViewModel, IValidatableObject
GatewaySubscriptionUrl = gatewaySubscriptionUrl; GatewaySubscriptionUrl = gatewaySubscriptionUrl;
Type = provider.Type; Type = provider.Type;
if (Type == ProviderType.MultiOrganizationEnterprise) if (Type == ProviderType.BusinessUnit)
{ {
var plan = providerPlans.SingleOrDefault(); var plan = providerPlans.SingleOrDefault();
EnterpriseMinimumSeats = plan?.SeatMinimum ?? 0; EnterpriseMinimumSeats = plan?.SeatMinimum ?? 0;
@ -100,7 +100,7 @@ public class ProviderEditModel : ProviderViewModel, IValidatableObject
yield return new ValidationResult($"The {billingEmailDisplayName} field is required."); yield return new ValidationResult($"The {billingEmailDisplayName} field is required.");
} }
break; break;
case ProviderType.MultiOrganizationEnterprise: case ProviderType.BusinessUnit:
if (Plan == null) if (Plan == null)
{ {
var displayName = nameof(Plan).GetDisplayAttribute<CreateProviderModel>()?.GetName() ?? nameof(Plan); var displayName = nameof(Plan).GetDisplayAttribute<CreateProviderModel>()?.GetName() ?? nameof(Plan);

View File

@ -40,7 +40,7 @@ public class ProviderViewModel
ProviderPlanViewModels.Add(new ProviderPlanViewModel("Enterprise (Monthly) Subscription", enterpriseProviderPlan, usedEnterpriseSeats)); ProviderPlanViewModels.Add(new ProviderPlanViewModel("Enterprise (Monthly) Subscription", enterpriseProviderPlan, usedEnterpriseSeats));
} }
} }
else if (Provider.Type == ProviderType.MultiOrganizationEnterprise) else if (Provider.Type == ProviderType.BusinessUnit)
{ {
var usedEnterpriseSeats = ProviderOrganizations.Where(po => po.PlanType == PlanType.EnterpriseMonthly) var usedEnterpriseSeats = ProviderOrganizations.Where(po => po.PlanType == PlanType.EnterpriseMonthly)
.Sum(po => po.OccupiedSeats).GetValueOrDefault(0); .Sum(po => po.OccupiedSeats).GetValueOrDefault(0);

View File

@ -1,8 +1,13 @@
@using Bit.Admin.Enums; @using Bit.Admin.Enums;
@using Bit.Admin.Models @using Bit.Admin.Models
@using Bit.Core
@using Bit.Core.AdminConsole.Enums.Provider
@using Bit.Core.Billing.Enums @using Bit.Core.Billing.Enums
@using Bit.Core.Enums @using Bit.Core.Billing.Extensions
@using Bit.Core.Services
@using Microsoft.AspNetCore.Mvc.TagHelpers
@inject Bit.Admin.Services.IAccessControlService AccessControlService @inject Bit.Admin.Services.IAccessControlService AccessControlService
@inject IFeatureService FeatureService
@model OrganizationEditModel @model OrganizationEditModel
@{ @{
ViewData["Title"] = (Model.Provider != null ? "Client " : string.Empty) + "Organization: " + Model.Name; ViewData["Title"] = (Model.Provider != null ? "Client " : string.Empty) + "Organization: " + Model.Name;
@ -13,6 +18,13 @@
var canRequestDelete = AccessControlService.UserHasPermission(Permission.Org_RequestDelete); var canRequestDelete = AccessControlService.UserHasPermission(Permission.Org_RequestDelete);
var canDelete = AccessControlService.UserHasPermission(Permission.Org_Delete); var canDelete = AccessControlService.UserHasPermission(Permission.Org_Delete);
var canUnlinkFromProvider = AccessControlService.UserHasPermission(Permission.Provider_Edit); var canUnlinkFromProvider = AccessControlService.UserHasPermission(Permission.Provider_Edit);
var canConvertToBusinessUnit =
FeatureService.IsEnabled(FeatureFlagKeys.PM18770_EnableOrganizationBusinessUnitConversion) &&
AccessControlService.UserHasPermission(Permission.Org_Billing_ConvertToBusinessUnit) &&
Model.Organization.PlanType.GetProductTier() == ProductTierType.Enterprise &&
!string.IsNullOrEmpty(Model.Organization.GatewaySubscriptionId) &&
Model.Provider is null or { Type: ProviderType.BusinessUnit, Status: ProviderStatusType.Pending };
} }
@section Scripts { @section Scripts {
@ -114,6 +126,15 @@
Enterprise Trial Enterprise Trial
</button> </button>
} }
@if (canConvertToBusinessUnit)
{
<a asp-controller="BusinessUnitConversion"
asp-action="Index"
asp-route-organizationId="@Model.Organization.Id"
class="btn btn-secondary me-2">
Convert to Business Unit
</a>
}
@if (canUnlinkFromProvider && Model.Provider is not null) @if (canUnlinkFromProvider && Model.Provider is not null)
{ {
<button class="btn btn-outline-danger me-2" <button class="btn btn-outline-danger me-2"

View File

@ -1,15 +1,15 @@
@using Bit.Core.Billing.Enums @using Bit.Core.Billing.Enums
@using Microsoft.AspNetCore.Mvc.TagHelpers @using Microsoft.AspNetCore.Mvc.TagHelpers
@model CreateMultiOrganizationEnterpriseProviderModel @model CreateBusinessUnitProviderModel
@{ @{
ViewData["Title"] = "Create Multi-organization Enterprise Provider"; ViewData["Title"] = "Create Business Unit Provider";
} }
<h1 class="mb-4">Create Multi-organization Enterprise Provider</h1> <h1 class="mb-4">Create Business Unit Provider</h1>
<div> <div>
<form method="post" asp-action="CreateMultiOrganizationEnterprise"> <form method="post" asp-action="CreateBusinessUnit">
<div asp-validation-summary="All" class="alert alert-danger"></div> <div asp-validation-summary="All" class="alert alert-danger"></div>
<div class="mb-3"> <div class="mb-3">
<label asp-for="OwnerEmail" class="form-label"></label> <label asp-for="OwnerEmail" class="form-label"></label>
@ -19,14 +19,14 @@
<div class="col-sm"> <div class="col-sm">
<div class="mb-3"> <div class="mb-3">
@{ @{
var multiOrgPlans = new List<PlanType> var businessUnitPlanTypes = new List<PlanType>
{ {
PlanType.EnterpriseAnnually, PlanType.EnterpriseAnnually,
PlanType.EnterpriseMonthly PlanType.EnterpriseMonthly
}; };
} }
<label asp-for="Plan" class="form-label"></label> <label asp-for="Plan" class="form-label"></label>
<select class="form-select" asp-for="Plan" asp-items="Html.GetEnumSelectList(multiOrgPlans)"> <select class="form-select" asp-for="Plan" asp-items="Html.GetEnumSelectList(businessUnitPlanTypes)">
<option value="">--</option> <option value="">--</option>
</select> </select>
</div> </div>

View File

@ -74,20 +74,20 @@
</div> </div>
break; break;
} }
case ProviderType.MultiOrganizationEnterprise: case ProviderType.BusinessUnit:
{ {
<div class="row"> <div class="row">
<div class="col-sm"> <div class="col-sm">
<div class="mb-3"> <div class="mb-3">
@{ @{
var multiOrgPlans = new List<PlanType> var businessUnitPlanTypes = new List<PlanType>
{ {
PlanType.EnterpriseAnnually, PlanType.EnterpriseAnnually,
PlanType.EnterpriseMonthly PlanType.EnterpriseMonthly
}; };
} }
<label asp-for="Plan" class="form-label"></label> <label asp-for="Plan" class="form-label"></label>
<select class="form-control" asp-for="Plan" asp-items="Html.GetEnumSelectList(multiOrgPlans)"> <select class="form-control" asp-for="Plan" asp-items="Html.GetEnumSelectList(businessUnitPlanTypes)">
<option value="">--</option> <option value="">--</option>
</select> </select>
</div> </div>

View File

@ -0,0 +1,185 @@
#nullable enable
using Bit.Admin.Billing.Models;
using Bit.Admin.Enums;
using Bit.Admin.Utilities;
using Bit.Core;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Billing.Services;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
namespace Bit.Admin.Billing.Controllers;
[Authorize]
[Route("organizations/billing/{organizationId:guid}/business-unit")]
[RequireFeature(FeatureFlagKeys.PM18770_EnableOrganizationBusinessUnitConversion)]
public class BusinessUnitConversionController(
IBusinessUnitConverter businessUnitConverter,
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
IProviderUserRepository providerUserRepository) : Controller
{
[HttpGet]
[RequirePermission(Permission.Org_Billing_ConvertToBusinessUnit)]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> IndexAsync([FromRoute] Guid organizationId)
{
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
var model = new BusinessUnitConversionModel { Organization = organization };
var invitedProviderAdmin = await GetInvitedProviderAdminAsync(organization);
if (invitedProviderAdmin != null)
{
model.ProviderAdminEmail = invitedProviderAdmin.Email;
model.ProviderId = invitedProviderAdmin.ProviderId;
}
var success = ReadSuccessMessage();
if (!string.IsNullOrEmpty(success))
{
model.Success = success;
}
var errors = ReadErrorMessages();
if (errors is { Count: > 0 })
{
model.Errors = errors;
}
return View(model);
}
[HttpPost]
[RequirePermission(Permission.Org_Billing_ConvertToBusinessUnit)]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> InitiateAsync(
[FromRoute] Guid organizationId,
BusinessUnitConversionModel model)
{
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
var result = await businessUnitConverter.InitiateConversion(
organization,
model.ProviderAdminEmail!);
return result.Match(
providerId => RedirectToAction("Edit", "Providers", new { id = providerId }),
errors =>
{
PersistErrorMessages(errors);
return RedirectToAction("Index", new { organizationId });
});
}
[HttpPost("reset")]
[RequirePermission(Permission.Org_Billing_ConvertToBusinessUnit)]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> ResetAsync(
[FromRoute] Guid organizationId,
BusinessUnitConversionModel model)
{
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
await businessUnitConverter.ResetConversion(organization, model.ProviderAdminEmail!);
PersistSuccessMessage("Business unit conversion was successfully reset.");
return RedirectToAction("Index", new { organizationId });
}
[HttpPost("resend-invite")]
[RequirePermission(Permission.Org_Billing_ConvertToBusinessUnit)]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IActionResult> ResendInviteAsync(
[FromRoute] Guid organizationId,
BusinessUnitConversionModel model)
{
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
throw new NotFoundException();
}
await businessUnitConverter.ResendConversionInvite(organization, model.ProviderAdminEmail!);
PersistSuccessMessage($"Invite was successfully resent to {model.ProviderAdminEmail}.");
return RedirectToAction("Index", new { organizationId });
}
private async Task<ProviderUser?> GetInvitedProviderAdminAsync(
Organization organization)
{
var provider = await providerRepository.GetByOrganizationIdAsync(organization.Id);
if (provider is not
{
Type: ProviderType.BusinessUnit,
Status: ProviderStatusType.Pending
})
{
return null;
}
var providerUsers =
await providerUserRepository.GetManyByProviderAsync(provider.Id, ProviderUserType.ProviderAdmin);
if (providerUsers.Count != 1)
{
return null;
}
var providerUser = providerUsers.First();
return providerUser is
{
Type: ProviderUserType.ProviderAdmin,
Status: ProviderUserStatusType.Invited,
UserId: not null
} ? providerUser : null;
}
private const string _errors = "errors";
private const string _success = "Success";
private void PersistSuccessMessage(string message) => TempData[_success] = message;
private void PersistErrorMessages(List<string> errors)
{
var input = string.Join("|", errors);
TempData[_errors] = input;
}
private string? ReadSuccessMessage() => ReadTempData<string>(_success);
private List<string>? ReadErrorMessages()
{
var output = ReadTempData<string>(_errors);
return string.IsNullOrEmpty(output) ? null : output.Split('|').ToList();
}
private T? ReadTempData<T>(string key) => TempData.TryGetValue(key, out var obj) && obj is T value ? value : default;
}

View File

@ -0,0 +1,25 @@
#nullable enable
using System.ComponentModel.DataAnnotations;
using Bit.Core.AdminConsole.Entities;
using Microsoft.AspNetCore.Mvc.ModelBinding;
namespace Bit.Admin.Billing.Models;
public class BusinessUnitConversionModel
{
[Required]
[EmailAddress]
[Display(Name = "Provider Admin Email")]
public string? ProviderAdminEmail { get; set; }
[BindNever]
public required Organization Organization { get; set; }
[BindNever]
public Guid? ProviderId { get; set; }
[BindNever]
public string? Success { get; set; }
[BindNever] public List<string>? Errors { get; set; } = [];
}

View File

@ -0,0 +1,75 @@
@model Bit.Admin.Billing.Models.BusinessUnitConversionModel
@{
ViewData["Title"] = "Convert Organization to Business Unit";
}
@if (!string.IsNullOrEmpty(Model.ProviderAdminEmail))
{
<h1>Convert @Model.Organization.Name to Business Unit</h1>
@if (!string.IsNullOrEmpty(Model.Success))
{
<div class="alert alert-success alert-dismissible fade show mb-3" role="alert">
@Model.Success
<button type="button" class="btn-close" data-bs-dismiss="alert" aria-label="Close"></button>
</div>
}
@if (Model.Errors?.Any() ?? false)
{
@foreach (var error in Model.Errors)
{
<div class="alert alert-danger alert-dismissible fade show mb-3" role="alert">
@error
<button type="button" class="btn-close" data-bs-dismiss="alert" aria-label="Close"></button>
</div>
}
}
<p>This organization has a business unit conversion in progress.</p>
<div class="mb-3">
<label asp-for="ProviderAdminEmail" class="form-label"></label>
<input type="email" class="form-control" asp-for="ProviderAdminEmail" disabled></input>
</div>
<div class="d-flex gap-2">
<form method="post" asp-controller="BusinessUnitConversion" asp-action="ResendInvite" asp-route-organizationId="@Model.Organization.Id">
<input type="hidden" asp-for="ProviderAdminEmail" />
<button type="submit" class="btn btn-primary mb-2">Resend Invite</button>
</form>
<form method="post" asp-controller="BusinessUnitConversion" asp-action="Reset" asp-route-organizationId="@Model.Organization.Id">
<input type="hidden" asp-for="ProviderAdminEmail" />
<button type="submit" class="btn btn-danger mb-2">Reset Conversion</button>
</form>
@if (Model.ProviderId.HasValue)
{
<a asp-controller="Providers"
asp-action="Edit"
asp-route-id="@Model.ProviderId"
class="btn btn-secondary mb-2">
Go to Provider
</a>
}
</div>
}
else
{
<h1>Convert @Model.Organization.Name to Business Unit</h1>
@if (Model.Errors?.Any() ?? false)
{
@foreach (var error in Model.Errors)
{
<div class="alert alert-danger alert-dismissible fade show mb-3" role="alert">
@error
<button type="button" class="btn-close" data-bs-dismiss="alert" aria-label="Close"></button>
</div>
}
}
<form method="post" asp-controller="BusinessUnitConversion" asp-action="Initiate" asp-route-organizationId="@Model.Organization.Id">
<div asp-validation-summary="All" class="alert alert-danger"></div>
<div class="mb-3">
<label asp-for="ProviderAdminEmail" class="form-label"></label>
<input type="email" class="form-control" asp-for="ProviderAdminEmail" />
</div>
<button type="submit" class="btn btn-primary mb-2">Convert</button>
</form>
}

View File

@ -38,6 +38,7 @@ public enum Permission
Org_Billing_View, Org_Billing_View,
Org_Billing_Edit, Org_Billing_Edit,
Org_Billing_LaunchGateway, Org_Billing_LaunchGateway,
Org_Billing_ConvertToBusinessUnit,
Provider_List_View, Provider_List_View,
Provider_Create, Provider_Create,

View File

@ -42,6 +42,7 @@ public static class RolePermissionMapping
Permission.Org_Billing_View, Permission.Org_Billing_View,
Permission.Org_Billing_Edit, Permission.Org_Billing_Edit,
Permission.Org_Billing_LaunchGateway, Permission.Org_Billing_LaunchGateway,
Permission.Org_Billing_ConvertToBusinessUnit,
Permission.Provider_List_View, Permission.Provider_List_View,
Permission.Provider_Create, Permission.Provider_Create,
Permission.Provider_View, Permission.Provider_View,
@ -90,6 +91,7 @@ public static class RolePermissionMapping
Permission.Org_Billing_View, Permission.Org_Billing_View,
Permission.Org_Billing_Edit, Permission.Org_Billing_Edit,
Permission.Org_Billing_LaunchGateway, Permission.Org_Billing_LaunchGateway,
Permission.Org_Billing_ConvertToBusinessUnit,
Permission.Org_InitiateTrial, Permission.Org_InitiateTrial,
Permission.Provider_List_View, Permission.Provider_List_View,
Permission.Provider_Create, Permission.Provider_Create,
@ -166,6 +168,7 @@ public static class RolePermissionMapping
Permission.Org_Billing_View, Permission.Org_Billing_View,
Permission.Org_Billing_Edit, Permission.Org_Billing_Edit,
Permission.Org_Billing_LaunchGateway, Permission.Org_Billing_LaunchGateway,
Permission.Org_Billing_ConvertToBusinessUnit,
Permission.Org_RequestDelete, Permission.Org_RequestDelete,
Permission.Provider_Edit, Permission.Provider_Edit,
Permission.Provider_View, Permission.Provider_View,

View File

@ -65,6 +65,7 @@ public class OrganizationsController : Controller
private readonly IOrganizationDeleteCommand _organizationDeleteCommand; private readonly IOrganizationDeleteCommand _organizationDeleteCommand;
private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly IOrganizationUpdateKeysCommand _organizationUpdateKeysCommand;
public OrganizationsController( public OrganizationsController(
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
@ -88,7 +89,8 @@ public class OrganizationsController : Controller
ICloudOrganizationSignUpCommand cloudOrganizationSignUpCommand, ICloudOrganizationSignUpCommand cloudOrganizationSignUpCommand,
IOrganizationDeleteCommand organizationDeleteCommand, IOrganizationDeleteCommand organizationDeleteCommand,
IPolicyRequirementQuery policyRequirementQuery, IPolicyRequirementQuery policyRequirementQuery,
IPricingClient pricingClient) IPricingClient pricingClient,
IOrganizationUpdateKeysCommand organizationUpdateKeysCommand)
{ {
_organizationRepository = organizationRepository; _organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository; _organizationUserRepository = organizationUserRepository;
@ -112,6 +114,7 @@ public class OrganizationsController : Controller
_organizationDeleteCommand = organizationDeleteCommand; _organizationDeleteCommand = organizationDeleteCommand;
_policyRequirementQuery = policyRequirementQuery; _policyRequirementQuery = policyRequirementQuery;
_pricingClient = pricingClient; _pricingClient = pricingClient;
_organizationUpdateKeysCommand = organizationUpdateKeysCommand;
} }
[HttpGet("{id}")] [HttpGet("{id}")]
@ -490,7 +493,7 @@ public class OrganizationsController : Controller
} }
[HttpPost("{id}/keys")] [HttpPost("{id}/keys")]
public async Task<OrganizationKeysResponseModel> PostKeys(string id, [FromBody] OrganizationKeysRequestModel model) public async Task<OrganizationKeysResponseModel> PostKeys(Guid id, [FromBody] OrganizationKeysRequestModel model)
{ {
var user = await _userService.GetUserByPrincipalAsync(User); var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null) if (user == null)
@ -498,7 +501,7 @@ public class OrganizationsController : Controller
throw new UnauthorizedAccessException(); throw new UnauthorizedAccessException();
} }
var org = await _organizationService.UpdateOrganizationKeysAsync(new Guid(id), model.PublicKey, var org = await _organizationUpdateKeysCommand.UpdateOrganizationKeysAsync(id, model.PublicKey,
model.EncryptedPrivateKey); model.EncryptedPrivateKey);
return new OrganizationKeysResponseModel(org); return new OrganizationKeysResponseModel(org);
} }

View File

@ -67,10 +67,12 @@ public class OrganizationUserDetailsResponseModel : OrganizationUserResponseMode
public OrganizationUserDetailsResponseModel( public OrganizationUserDetailsResponseModel(
OrganizationUser organizationUser, OrganizationUser organizationUser,
bool claimedByOrganization, bool claimedByOrganization,
string ssoExternalId,
IEnumerable<CollectionAccessSelection> collections) IEnumerable<CollectionAccessSelection> collections)
: base(organizationUser, "organizationUserDetails") : base(organizationUser, "organizationUserDetails")
{ {
ClaimedByOrganization = claimedByOrganization; ClaimedByOrganization = claimedByOrganization;
SsoExternalId = ssoExternalId;
Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c));
} }
@ -80,6 +82,7 @@ public class OrganizationUserDetailsResponseModel : OrganizationUserResponseMode
: base(organizationUser, "organizationUserDetails") : base(organizationUser, "organizationUserDetails")
{ {
ClaimedByOrganization = claimedByOrganization; ClaimedByOrganization = claimedByOrganization;
SsoExternalId = organizationUser.SsoExternalId;
Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c));
} }
@ -90,6 +93,7 @@ public class OrganizationUserDetailsResponseModel : OrganizationUserResponseMode
set => ClaimedByOrganization = value; set => ClaimedByOrganization = value;
} }
public bool ClaimedByOrganization { get; set; } public bool ClaimedByOrganization { get; set; }
public string SsoExternalId { get; set; }
public IEnumerable<SelectionReadOnlyResponseModel> Collections { get; set; } public IEnumerable<SelectionReadOnlyResponseModel> Collections { get; set; }

View File

@ -22,6 +22,7 @@ public class ProfileProviderResponseModel : ResponseModel
UserId = provider.UserId; UserId = provider.UserId;
UseEvents = provider.UseEvents; UseEvents = provider.UseEvents;
ProviderStatus = provider.ProviderStatus; ProviderStatus = provider.ProviderStatus;
ProviderType = provider.ProviderType;
} }
public Guid Id { get; set; } public Guid Id { get; set; }
@ -35,4 +36,5 @@ public class ProfileProviderResponseModel : ResponseModel
public Guid? UserId { get; set; } public Guid? UserId { get; set; }
public bool UseEvents { get; set; } public bool UseEvents { get; set; }
public ProviderStatusType ProviderStatus { get; set; } public ProviderStatusType ProviderStatus { get; set; }
public ProviderType ProviderType { get; set; }
} }

View File

@ -0,0 +1,11 @@
using System.ComponentModel.DataAnnotations;
#nullable enable
namespace Bit.Api.Auth.Models.Request;
public class UntrustDevicesRequestModel
{
[Required]
public IEnumerable<Guid> Devices { get; set; } = null!;
}

View File

@ -2,6 +2,7 @@
using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Request.Organizations;
using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Requests;
using Bit.Api.Billing.Models.Responses; using Bit.Api.Billing.Models.Responses;
using Bit.Core;
using Bit.Core.Billing.Models; using Bit.Core.Billing.Models;
using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Models.Sales;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
@ -18,7 +19,9 @@ namespace Bit.Api.Billing.Controllers;
[Route("organizations/{organizationId:guid}/billing")] [Route("organizations/{organizationId:guid}/billing")]
[Authorize("Application")] [Authorize("Application")]
public class OrganizationBillingController( public class OrganizationBillingController(
IBusinessUnitConverter businessUnitConverter,
ICurrentContext currentContext, ICurrentContext currentContext,
IFeatureService featureService,
IOrganizationBillingService organizationBillingService, IOrganizationBillingService organizationBillingService,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
IPaymentService paymentService, IPaymentService paymentService,
@ -296,4 +299,40 @@ public class OrganizationBillingController(
return TypedResults.Ok(); return TypedResults.Ok();
} }
[HttpPost("setup-business-unit")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<IResult> SetupBusinessUnitAsync(
[FromRoute] Guid organizationId,
[FromBody] SetupBusinessUnitRequestBody requestBody)
{
var enableOrganizationBusinessUnitConversion =
featureService.IsEnabled(FeatureFlagKeys.PM18770_EnableOrganizationBusinessUnitConversion);
if (!enableOrganizationBusinessUnitConversion)
{
return Error.NotFound();
}
var organization = await organizationRepository.GetByIdAsync(organizationId);
if (organization == null)
{
return Error.NotFound();
}
if (!await currentContext.OrganizationUser(organizationId))
{
return Error.Unauthorized();
}
var providerId = await businessUnitConverter.FinalizeConversion(
organization,
requestBody.UserId,
requestBody.Token,
requestBody.ProviderKey,
requestBody.OrganizationKey);
return TypedResults.Ok(providerId);
}
} }

View File

@ -0,0 +1,18 @@
using System.ComponentModel.DataAnnotations;
namespace Bit.Api.Billing.Models.Requests;
public class SetupBusinessUnitRequestBody
{
[Required]
public Guid UserId { get; set; }
[Required]
public string Token { get; set; }
[Required]
public string ProviderKey { get; set; }
[Required]
public string OrganizationKey { get; set; }
}

View File

@ -4,6 +4,7 @@ using Bit.Api.Models.Request;
using Bit.Api.Models.Response; using Bit.Api.Models.Response;
using Bit.Core.Auth.Models.Api.Request; using Bit.Core.Auth.Models.Api.Request;
using Bit.Core.Auth.Models.Api.Response; using Bit.Core.Auth.Models.Api.Response;
using Bit.Core.Auth.UserFeatures.DeviceTrust;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Exceptions; using Bit.Core.Exceptions;
using Bit.Core.Repositories; using Bit.Core.Repositories;
@ -21,6 +22,7 @@ public class DevicesController : Controller
private readonly IDeviceRepository _deviceRepository; private readonly IDeviceRepository _deviceRepository;
private readonly IDeviceService _deviceService; private readonly IDeviceService _deviceService;
private readonly IUserService _userService; private readonly IUserService _userService;
private readonly IUntrustDevicesCommand _untrustDevicesCommand;
private readonly IUserRepository _userRepository; private readonly IUserRepository _userRepository;
private readonly ICurrentContext _currentContext; private readonly ICurrentContext _currentContext;
private readonly ILogger<DevicesController> _logger; private readonly ILogger<DevicesController> _logger;
@ -29,6 +31,7 @@ public class DevicesController : Controller
IDeviceRepository deviceRepository, IDeviceRepository deviceRepository,
IDeviceService deviceService, IDeviceService deviceService,
IUserService userService, IUserService userService,
IUntrustDevicesCommand untrustDevicesCommand,
IUserRepository userRepository, IUserRepository userRepository,
ICurrentContext currentContext, ICurrentContext currentContext,
ILogger<DevicesController> logger) ILogger<DevicesController> logger)
@ -36,6 +39,7 @@ public class DevicesController : Controller
_deviceRepository = deviceRepository; _deviceRepository = deviceRepository;
_deviceService = deviceService; _deviceService = deviceService;
_userService = userService; _userService = userService;
_untrustDevicesCommand = untrustDevicesCommand;
_userRepository = userRepository; _userRepository = userRepository;
_currentContext = currentContext; _currentContext = currentContext;
_logger = logger; _logger = logger;
@ -165,6 +169,19 @@ public class DevicesController : Controller
model.OtherDevices ?? Enumerable.Empty<OtherDeviceKeysUpdateRequestModel>()); model.OtherDevices ?? Enumerable.Empty<OtherDeviceKeysUpdateRequestModel>());
} }
[HttpPost("untrust")]
public async Task PostUntrust([FromBody] UntrustDevicesRequestModel model)
{
var user = await _userService.GetUserByPrincipalAsync(User);
if (user == null)
{
throw new UnauthorizedAccessException();
}
await _untrustDevicesCommand.UntrustDevices(user, model.Devices);
}
[HttpPut("identifier/{identifier}/token")] [HttpPut("identifier/{identifier}/token")]
[HttpPost("identifier/{identifier}/token")] [HttpPost("identifier/{identifier}/token")]
public async Task PutToken(string identifier, [FromBody] DeviceTokenRequestModel model) public async Task PutToken(string identifier, [FromBody] DeviceTokenRequestModel model)

View File

@ -1,6 +1,5 @@
using Bit.Billing.Constants; using Bit.Billing.Constants;
using Bit.Billing.Jobs; using Bit.Billing.Jobs;
using Bit.Core;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Pricing;
using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces;
@ -24,7 +23,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
private readonly IPushNotificationService _pushNotificationService; private readonly IPushNotificationService _pushNotificationService;
private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationRepository _organizationRepository;
private readonly ISchedulerFactory _schedulerFactory; private readonly ISchedulerFactory _schedulerFactory;
private readonly IFeatureService _featureService;
private readonly IOrganizationEnableCommand _organizationEnableCommand; private readonly IOrganizationEnableCommand _organizationEnableCommand;
private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
@ -39,7 +37,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
IPushNotificationService pushNotificationService, IPushNotificationService pushNotificationService,
IOrganizationRepository organizationRepository, IOrganizationRepository organizationRepository,
ISchedulerFactory schedulerFactory, ISchedulerFactory schedulerFactory,
IFeatureService featureService,
IOrganizationEnableCommand organizationEnableCommand, IOrganizationEnableCommand organizationEnableCommand,
IOrganizationDisableCommand organizationDisableCommand, IOrganizationDisableCommand organizationDisableCommand,
IPricingClient pricingClient) IPricingClient pricingClient)
@ -53,7 +50,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
_pushNotificationService = pushNotificationService; _pushNotificationService = pushNotificationService;
_organizationRepository = organizationRepository; _organizationRepository = organizationRepository;
_schedulerFactory = schedulerFactory; _schedulerFactory = schedulerFactory;
_featureService = featureService;
_organizationEnableCommand = organizationEnableCommand; _organizationEnableCommand = organizationEnableCommand;
_organizationDisableCommand = organizationDisableCommand; _organizationDisableCommand = organizationDisableCommand;
_pricingClient = pricingClient; _pricingClient = pricingClient;
@ -227,12 +223,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
private async Task ScheduleCancellationJobAsync(string subscriptionId, Guid organizationId) private async Task ScheduleCancellationJobAsync(string subscriptionId, Guid organizationId)
{ {
var isResellerManagedOrgAlertEnabled = _featureService.IsEnabled(FeatureFlagKeys.ResellerManagedOrgAlert);
if (!isResellerManagedOrgAlertEnabled)
{
return;
}
var scheduler = await _schedulerFactory.GetScheduler(); var scheduler = await _schedulerFactory.GetScheduler();
var job = JobBuilder.Create<SubscriptionCancellationJob>() var job = JobBuilder.Create<SubscriptionCancellationJob>()

View File

@ -8,6 +8,6 @@ public enum ProviderType : byte
Msp = 0, Msp = 0,
[Display(ShortName = "Reseller", Name = "Reseller", Description = "Creates Bitwarden Portal page for client organization billing management", Order = 1000)] [Display(ShortName = "Reseller", Name = "Reseller", Description = "Creates Bitwarden Portal page for client organization billing management", Order = 1000)]
Reseller = 1, Reseller = 1,
[Display(ShortName = "MOE", Name = "Multi-organization Enterprises", Description = "Creates provider portal for multi-organization management", Order = 1)] [Display(ShortName = "Business Unit", Name = "Business Unit", Description = "Creates provider portal for business unit management", Order = 1)]
MultiOrganizationEnterprise = 2, BusinessUnit = 2,
} }

View File

@ -17,4 +17,5 @@ public class ProviderUserProviderDetails
public string Permissions { get; set; } public string Permissions { get; set; }
public bool UseEvents { get; set; } public bool UseEvents { get; set; }
public ProviderStatusType ProviderStatus { get; set; } public ProviderStatusType ProviderStatus { get; set; }
public ProviderType ProviderType { get; set; }
} }

View File

@ -0,0 +1,13 @@
using Bit.Core.AdminConsole.Entities;
public interface IOrganizationUpdateKeysCommand
{
/// <summary>
/// Update the keys for an organization.
/// </summary>
/// <param name="orgId">The ID of the organization to update.</param>
/// <param name="publicKey">The public key for the organization.</param>
/// <param name="privateKey">The private key for the organization.</param>
/// <returns>The updated organization.</returns>
Task<Organization> UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey);
}

View File

@ -0,0 +1,47 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
public class OrganizationUpdateKeysCommand : IOrganizationUpdateKeysCommand
{
private readonly ICurrentContext _currentContext;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationService _organizationService;
public const string OrganizationKeysAlreadyExistErrorMessage = "Organization Keys already exist.";
public OrganizationUpdateKeysCommand(
ICurrentContext currentContext,
IOrganizationRepository organizationRepository,
IOrganizationService organizationService)
{
_currentContext = currentContext;
_organizationRepository = organizationRepository;
_organizationService = organizationService;
}
public async Task<Organization> UpdateOrganizationKeysAsync(Guid organizationId, string publicKey, string privateKey)
{
if (!await _currentContext.ManageResetPassword(organizationId))
{
throw new UnauthorizedAccessException();
}
// If the keys already exist, error out
var organization = await _organizationRepository.GetByIdAsync(organizationId);
if (organization.PublicKey != null && organization.PrivateKey != null)
{
throw new BadRequestException(OrganizationKeysAlreadyExistErrorMessage);
}
// Update org with generated public/private key
organization.PublicKey = publicKey;
organization.PrivateKey = privateKey;
await _organizationService.UpdateAsync(organization);
return organization;
}
}

View File

@ -7,5 +7,5 @@ public interface ICreateProviderCommand
{ {
Task CreateMspAsync(Provider provider, string ownerEmail, int teamsMinimumSeats, int enterpriseMinimumSeats); Task CreateMspAsync(Provider provider, string ownerEmail, int teamsMinimumSeats, int enterpriseMinimumSeats);
Task CreateResellerAsync(Provider provider); Task CreateResellerAsync(Provider provider);
Task CreateMultiOrganizationEnterpriseAsync(Provider provider, string ownerEmail, PlanType plan, int minimumSeats); Task CreateBusinessUnitAsync(Provider provider, string ownerEmail, PlanType plan, int minimumSeats);
} }

View File

@ -41,7 +41,6 @@ public interface IOrganizationService
IEnumerable<ImportedOrganizationUser> newUsers, IEnumerable<string> removeUserExternalIds, IEnumerable<ImportedOrganizationUser> newUsers, IEnumerable<string> removeUserExternalIds,
bool overwriteExisting, EventSystemUser eventSystemUser); bool overwriteExisting, EventSystemUser eventSystemUser);
Task DeleteSsoUserAsync(Guid userId, Guid? organizationId); Task DeleteSsoUserAsync(Guid userId, Guid? organizationId);
Task<Organization> UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey);
Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId); Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId);
Task RevokeUserAsync(OrganizationUser organizationUser, EventSystemUser systemUser); Task RevokeUserAsync(OrganizationUser organizationUser, EventSystemUser systemUser);
Task<List<Tuple<OrganizationUser, string>>> RevokeUsersAsync(Guid organizationId, Task<List<Tuple<OrganizationUser, string>>> RevokeUsersAsync(Guid organizationId,

View File

@ -1397,28 +1397,6 @@ public class OrganizationService : IOrganizationService
} }
} }
public async Task<Organization> UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey)
{
if (!await _currentContext.ManageResetPassword(orgId))
{
throw new UnauthorizedAccessException();
}
// If the keys already exist, error out
var org = await _organizationRepository.GetByIdAsync(orgId);
if (org.PublicKey != null && org.PrivateKey != null)
{
throw new BadRequestException("Organization Keys already exist");
}
// Update org with generated public/private key
org.PublicKey = publicKey;
org.PrivateKey = privateKey;
await UpdateAsync(org);
return org;
}
private async Task UpdateUsersAsync(Group group, HashSet<string> groupUsers, private async Task UpdateUsersAsync(Group group, HashSet<string> groupUsers,
Dictionary<string, Guid> existingUsersIdDict, HashSet<Guid> existingUsers = null) Dictionary<string, Guid> existingUsersIdDict, HashSet<Guid> existingUsers = null)
{ {

View File

@ -0,0 +1,8 @@
using Bit.Core.Entities;
namespace Bit.Core.Auth.UserFeatures.DeviceTrust;
public interface IUntrustDevicesCommand
{
public Task UntrustDevices(User user, IEnumerable<Guid> devicesToUntrust);
}

View File

@ -0,0 +1,39 @@
using Bit.Core.Entities;
using Bit.Core.Repositories;
namespace Bit.Core.Auth.UserFeatures.DeviceTrust;
public class UntrustDevicesCommand : IUntrustDevicesCommand
{
private readonly IDeviceRepository _deviceRepository;
public UntrustDevicesCommand(
IDeviceRepository deviceRepository)
{
_deviceRepository = deviceRepository;
}
public async Task UntrustDevices(User user, IEnumerable<Guid> devicesToUntrust)
{
var userDevices = await _deviceRepository.GetManyByUserIdAsync(user.Id);
var deviceIdDict = userDevices.ToDictionary(device => device.Id);
// Validate that the user owns all devices that they passed in
foreach (var deviceId in devicesToUntrust)
{
if (!deviceIdDict.ContainsKey(deviceId))
{
throw new UnauthorizedAccessException($"User {user.Id} does not have access to device {deviceId}");
}
}
foreach (var deviceId in devicesToUntrust)
{
var device = deviceIdDict[deviceId];
device.EncryptedPrivateKey = null;
device.EncryptedPublicKey = null;
device.EncryptedUserKey = null;
await _deviceRepository.UpsertAsync(device);
}
}
}

View File

@ -1,5 +1,6 @@
 
using Bit.Core.Auth.UserFeatures.DeviceTrust;
using Bit.Core.Auth.UserFeatures.Registration; using Bit.Core.Auth.UserFeatures.Registration;
using Bit.Core.Auth.UserFeatures.Registration.Implementations; using Bit.Core.Auth.UserFeatures.Registration.Implementations;
using Bit.Core.Auth.UserFeatures.TdeOffboardingPassword.Interfaces; using Bit.Core.Auth.UserFeatures.TdeOffboardingPassword.Interfaces;
@ -22,6 +23,7 @@ public static class UserServiceCollectionExtensions
public static void AddUserServices(this IServiceCollection services, IGlobalSettings globalSettings) public static void AddUserServices(this IServiceCollection services, IGlobalSettings globalSettings)
{ {
services.AddScoped<IUserService, UserService>(); services.AddScoped<IUserService, UserService>();
services.AddDeviceTrustCommands();
services.AddUserPasswordCommands(); services.AddUserPasswordCommands();
services.AddUserRegistrationCommands(); services.AddUserRegistrationCommands();
services.AddWebAuthnLoginCommands(); services.AddWebAuthnLoginCommands();
@ -29,6 +31,11 @@ public static class UserServiceCollectionExtensions
services.AddTwoFactorQueries(); services.AddTwoFactorQueries();
} }
public static void AddDeviceTrustCommands(this IServiceCollection services)
{
services.AddScoped<IUntrustDevicesCommand, UntrustDevicesCommand>();
}
public static void AddUserKeyCommands(this IServiceCollection services, IGlobalSettings globalSettings) public static void AddUserKeyCommands(this IServiceCollection services, IGlobalSettings globalSettings)
{ {
services.AddScoped<IRotateUserKeyCommand, RotateUserKeyCommand>(); services.AddScoped<IRotateUserKeyCommand, RotateUserKeyCommand>();

View File

@ -25,19 +25,19 @@ public static class BillingExtensions
public static bool IsBillable(this Provider provider) => public static bool IsBillable(this Provider provider) =>
provider is provider is
{ {
Type: ProviderType.Msp or ProviderType.MultiOrganizationEnterprise, Type: ProviderType.Msp or ProviderType.BusinessUnit,
Status: ProviderStatusType.Billable Status: ProviderStatusType.Billable
}; };
public static bool IsBillable(this InviteOrganizationProvider inviteOrganizationProvider) => public static bool IsBillable(this InviteOrganizationProvider inviteOrganizationProvider) =>
inviteOrganizationProvider is inviteOrganizationProvider is
{ {
Type: ProviderType.Msp or ProviderType.MultiOrganizationEnterprise, Type: ProviderType.Msp or ProviderType.BusinessUnit,
Status: ProviderStatusType.Billable Status: ProviderStatusType.Billable
}; };
public static bool SupportsConsolidatedBilling(this ProviderType providerType) public static bool SupportsConsolidatedBilling(this ProviderType providerType)
=> providerType is ProviderType.Msp or ProviderType.MultiOrganizationEnterprise; => providerType is ProviderType.Msp or ProviderType.BusinessUnit;
public static bool IsValidClient(this Organization organization) public static bool IsValidClient(this Organization organization)
=> organization is => organization is

View File

@ -0,0 +1,58 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using OneOf;
namespace Bit.Core.Billing.Services;
public interface IBusinessUnitConverter
{
/// <summary>
/// Finalizes the process of converting the <paramref name="organization"/> to a <see cref="ProviderType.BusinessUnit"/> by
/// saving all the necessary key provided by the client and updating the <paramref name="organization"/>'s subscription to a
/// provider subscription.
/// </summary>
/// <param name="organization">The organization to convert to a business unit.</param>
/// <param name="userId">The ID of the organization member who will be the provider admin.</param>
/// <param name="token">The token sent to the client as part of the <see cref="InitiateConversion"/> process.</param>
/// <param name="providerKey">The encrypted provider key used to enable the <see cref="ProviderUser"/>.</param>
/// <param name="organizationKey">The encrypted organization key used to enable the <see cref="ProviderOrganization"/>.</param>
/// <returns>The provider ID</returns>
Task<Guid> FinalizeConversion(
Organization organization,
Guid userId,
string token,
string providerKey,
string organizationKey);
/// <summary>
/// Begins the process of converting the <paramref name="organization"/> to a <see cref="ProviderType.BusinessUnit"/> by
/// creating all the necessary database entities and sending a setup invitation to the <paramref name="providerAdminEmail"/>.
/// </summary>
/// <param name="organization">The organization to convert to a business unit.</param>
/// <param name="providerAdminEmail">The email address of the organization member who will be the provider admin.</param>
/// <returns>Either the newly created provider ID or a list of validation failures.</returns>
Task<OneOf<Guid, List<string>>> InitiateConversion(
Organization organization,
string providerAdminEmail);
/// <summary>
/// Checks if the <paramref name="organization"/> has a business unit conversion in progress and, if it does, resends the
/// setup invitation to the provider admin.
/// </summary>
/// <param name="organization">The organization to convert to a business unit.</param>
/// <param name="providerAdminEmail">The email address of the organization member who will be the provider admin.</param>
Task ResendConversionInvite(
Organization organization,
string providerAdminEmail);
/// <summary>
/// Checks if the <paramref name="organization"/> has a business unit conversion in progress and, if it does, resets that conversion
/// by deleting all the database entities created as part of <see cref="InitiateConversion"/>.
/// </summary>
/// <param name="organization">The organization to convert to a business unit.</param>
/// <param name="providerAdminEmail">The email address of the organization member who will be the provider admin.</param>
Task ResetConversion(
Organization organization,
string providerAdminEmail);
}

View File

@ -93,7 +93,9 @@ public class OrganizationBillingService(
var isOnSecretsManagerStandalone = await IsOnSecretsManagerStandalone(organization, customer, subscription); var isOnSecretsManagerStandalone = await IsOnSecretsManagerStandalone(organization, customer, subscription);
var invoice = await stripeAdapter.InvoiceGetAsync(subscription.LatestInvoiceId, new InvoiceGetOptions()); var invoice = !string.IsNullOrEmpty(subscription.LatestInvoiceId)
? await stripeAdapter.InvoiceGetAsync(subscription.LatestInvoiceId, new InvoiceGetOptions())
: null;
return new OrganizationMetadata( return new OrganizationMetadata(
isEligibleForSelfHost, isEligibleForSelfHost,

View File

@ -141,13 +141,13 @@ public static class FeatureFlagKeys
/* Billing Team */ /* Billing Team */
public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email"; public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email";
public const string TrialPayment = "PM-8163-trial-payment"; public const string TrialPayment = "PM-8163-trial-payment";
public const string ResellerManagedOrgAlert = "PM-15814-alert-owners-of-reseller-managed-orgs";
public const string UsePricingService = "use-pricing-service"; public const string UsePricingService = "use-pricing-service";
public const string P15179_AddExistingOrgsFromProviderPortal = "pm-15179-add-existing-orgs-from-provider-portal"; public const string P15179_AddExistingOrgsFromProviderPortal = "pm-15179-add-existing-orgs-from-provider-portal";
public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features"; public const string PM12276Breadcrumbing = "pm-12276-breadcrumbing-for-business-features";
public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method"; public const string PM18794_ProviderPaymentMethod = "pm-18794-provider-payment-method";
public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements"; public const string PM19147_AutomaticTaxImprovements = "pm-19147-automatic-tax-improvements";
public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates";
public const string PM18770_EnableOrganizationBusinessUnitConversion = "pm-18770-enable-organization-business-unit-conversion";
/* Key Management Team */ /* Key Management Team */
public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair";
@ -156,6 +156,7 @@ public static class FeatureFlagKeys
public const string Argon2Default = "argon2-default"; public const string Argon2Default = "argon2-default";
public const string UserkeyRotationV2 = "userkey-rotation-v2"; public const string UserkeyRotationV2 = "userkey-rotation-v2";
public const string SSHKeyItemVaultItem = "ssh-key-vault-item"; public const string SSHKeyItemVaultItem = "ssh-key-vault-item";
public const string PM17987_BlockType0 = "pm-17987-block-type-0";
/* Mobile Team */ /* Mobile Team */
public const string NativeCarouselFlow = "native-carousel-flow"; public const string NativeCarouselFlow = "native-carousel-flow";
@ -169,14 +170,16 @@ public static class FeatureFlagKeys
public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication"; public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication";
public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync"; public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync";
public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias"; public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias";
public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias"; public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias";
public const string EnablePMFlightRecorder = "enable-pm-flight-recorder";
public const string MobileErrorReporting = "mobile-error-reporting";
/* Platform Team */ /* Platform Team */
public const string PersistPopupView = "persist-popup-view"; public const string PersistPopupView = "persist-popup-view";
public const string StorageReseedRefactor = "storage-reseed-refactor"; public const string StorageReseedRefactor = "storage-reseed-refactor";
public const string WebPush = "web-push"; public const string WebPush = "web-push";
public const string RecordInstallationLastActivityDate = "installation-last-activity-date"; public const string RecordInstallationLastActivityDate = "installation-last-activity-date";
public const string IpcChannelFramework = "ipc-channel-framework";
/* Tools Team */ /* Tools Team */
public const string ItemShare = "item-share"; public const string ItemShare = "item-share";
@ -196,6 +199,7 @@ public static class FeatureFlagKeys
public const string SecurityTasks = "security-tasks"; public const string SecurityTasks = "security-tasks";
public const string CipherKeyEncryption = "cipher-key-encryption"; public const string CipherKeyEncryption = "cipher-key-encryption";
public const string DesktopCipherForms = "pm-18520-desktop-cipher-forms"; public const string DesktopCipherForms = "pm-18520-desktop-cipher-forms";
public const string PM19941MigrateCipherDomainToSdk = "pm-19941-migrate-cipher-domain-to-sdk";
public static List<string> GetAllKeys() public static List<string> GetAllKeys()
{ {

View File

@ -0,0 +1,19 @@
{{#>FullHtmlLayout}}
<table width="100%" cellpadding="0" cellspacing="0" style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
<tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
<td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: left;" valign="top" align="center">
You have been invited to set up a new Business Unit Portal within Bitwarden.
<br style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;" />
<br style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;" />
</td>
</tr>
<tr style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
<td class="content-block" style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; color: #333; line-height: 25px; margin: 0; -webkit-font-smoothing: antialiased; padding: 0 0 10px; -webkit-text-size-adjust: none; text-align: center;" valign="top" align="center">
<a href="{{{Url}}}" clicktracking=off target="_blank" style="color: #ffffff; text-decoration: none; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; background-color: #175DDC; border-color: #175DDC; border-style: solid; border-width: 10px 20px; margin: 0; font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box; font-size: 16px; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;">
Set Up Business Unit Portal Now
</a>
<br style="margin: 0; box-sizing: border-box; color: #333; line-height: 25px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none;" />
</td>
</tr>
</table>
{{/FullHtmlLayout}}

View File

@ -0,0 +1,5 @@
{{#>BasicTextLayout}}
You have been invited to set up a new Business Unit Portal within Bitwarden. To continue, click the following link:
{{{Url}}}
{{/BasicTextLayout}}

View File

@ -0,0 +1,11 @@
namespace Bit.Core.Models.Mail.Billing;
public class BusinessUnitConversionInviteModel : BaseMailModel
{
public string OrganizationId { get; set; }
public string Email { get; set; }
public string Token { get; set; }
public string Url =>
$"{WebVaultUrl}/providers/setup-business-unit?organizationId={OrganizationId}&email={Email}&token={Token}";
}

View File

@ -32,20 +32,22 @@ public class NotificationHubPushNotificationService : IPushNotificationService
private readonly INotificationHubPool _notificationHubPool; private readonly INotificationHubPool _notificationHubPool;
private readonly ILogger _logger; private readonly ILogger _logger;
private readonly IGlobalSettings _globalSettings; private readonly IGlobalSettings _globalSettings;
private readonly TimeProvider _timeProvider;
public NotificationHubPushNotificationService( public NotificationHubPushNotificationService(
IInstallationDeviceRepository installationDeviceRepository, IInstallationDeviceRepository installationDeviceRepository,
INotificationHubPool notificationHubPool, INotificationHubPool notificationHubPool,
IHttpContextAccessor httpContextAccessor, IHttpContextAccessor httpContextAccessor,
ILogger<NotificationHubPushNotificationService> logger, ILogger<NotificationHubPushNotificationService> logger,
IGlobalSettings globalSettings) IGlobalSettings globalSettings,
TimeProvider timeProvider)
{ {
_installationDeviceRepository = installationDeviceRepository; _installationDeviceRepository = installationDeviceRepository;
_httpContextAccessor = httpContextAccessor; _httpContextAccessor = httpContextAccessor;
_notificationHubPool = notificationHubPool; _notificationHubPool = notificationHubPool;
_logger = logger; _logger = logger;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_timeProvider = timeProvider;
if (globalSettings.Installation.Id == Guid.Empty) if (globalSettings.Installation.Id == Guid.Empty)
{ {
logger.LogWarning("Installation ID is not set. Push notifications for installations will not work."); logger.LogWarning("Installation ID is not set. Push notifications for installations will not work.");
@ -152,7 +154,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService
private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false) private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false)
{ {
var message = new UserPushNotification { UserId = userId, Date = DateTime.UtcNow }; var message = new UserPushNotification { UserId = userId, Date = _timeProvider.GetUtcNow().UtcDateTime };
await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext); await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext);
} }

View File

@ -60,6 +60,7 @@ public static class OrganizationServiceCollectionExtensions
services.AddOrganizationDomainCommandsQueries(); services.AddOrganizationDomainCommandsQueries();
services.AddOrganizationSignUpCommands(); services.AddOrganizationSignUpCommands();
services.AddOrganizationDeleteCommands(); services.AddOrganizationDeleteCommands();
services.AddOrganizationUpdateCommands();
services.AddOrganizationEnableCommands(); services.AddOrganizationEnableCommands();
services.AddOrganizationDisableCommands(); services.AddOrganizationDisableCommands();
services.AddOrganizationAuthCommands(); services.AddOrganizationAuthCommands();
@ -77,6 +78,11 @@ public static class OrganizationServiceCollectionExtensions
services.AddScoped<IOrganizationInitiateDeleteCommand, OrganizationInitiateDeleteCommand>(); services.AddScoped<IOrganizationInitiateDeleteCommand, OrganizationInitiateDeleteCommand>();
} }
private static void AddOrganizationUpdateCommands(this IServiceCollection services)
{
services.AddScoped<IOrganizationUpdateKeysCommand, OrganizationUpdateKeysCommand>();
}
private static void AddOrganizationEnableCommands(this IServiceCollection services) => private static void AddOrganizationEnableCommands(this IServiceCollection services) =>
services.AddScoped<IOrganizationEnableCommand, OrganizationEnableCommand>(); services.AddScoped<IOrganizationEnableCommand, OrganizationEnableCommand>();

View File

@ -22,17 +22,19 @@ public class AzureQueuePushNotificationService : IPushNotificationService
private readonly QueueClient _queueClient; private readonly QueueClient _queueClient;
private readonly IHttpContextAccessor _httpContextAccessor; private readonly IHttpContextAccessor _httpContextAccessor;
private readonly IGlobalSettings _globalSettings; private readonly IGlobalSettings _globalSettings;
private readonly TimeProvider _timeProvider;
public AzureQueuePushNotificationService( public AzureQueuePushNotificationService(
[FromKeyedServices("notifications")] QueueClient queueClient, [FromKeyedServices("notifications")] QueueClient queueClient,
IHttpContextAccessor httpContextAccessor, IHttpContextAccessor httpContextAccessor,
IGlobalSettings globalSettings, IGlobalSettings globalSettings,
ILogger<AzureQueuePushNotificationService> logger) ILogger<AzureQueuePushNotificationService> logger,
TimeProvider timeProvider)
{ {
_queueClient = queueClient; _queueClient = queueClient;
_httpContextAccessor = httpContextAccessor; _httpContextAccessor = httpContextAccessor;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_timeProvider = timeProvider;
if (globalSettings.Installation.Id == Guid.Empty) if (globalSettings.Installation.Id == Guid.Empty)
{ {
logger.LogWarning("Installation ID is not set. Push notifications for installations will not work."); logger.LogWarning("Installation ID is not set. Push notifications for installations will not work.");
@ -140,7 +142,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService
private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false) private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false)
{ {
var message = new UserPushNotification { UserId = userId, Date = DateTime.UtcNow }; var message = new UserPushNotification { UserId = userId, Date = _timeProvider.GetUtcNow().UtcDateTime };
await SendMessageAsync(type, message, excludeCurrentContext); await SendMessageAsync(type, message, excludeCurrentContext);
} }

View File

@ -24,12 +24,14 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
{ {
private readonly IGlobalSettings _globalSettings; private readonly IGlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor; private readonly IHttpContextAccessor _httpContextAccessor;
private readonly TimeProvider _timeProvider;
public NotificationsApiPushNotificationService( public NotificationsApiPushNotificationService(
IHttpClientFactory httpFactory, IHttpClientFactory httpFactory,
GlobalSettings globalSettings, GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor, IHttpContextAccessor httpContextAccessor,
ILogger<NotificationsApiPushNotificationService> logger) ILogger<NotificationsApiPushNotificationService> logger,
TimeProvider timeProvider)
: base( : base(
httpFactory, httpFactory,
globalSettings.BaseServiceUri.InternalNotifications, globalSettings.BaseServiceUri.InternalNotifications,
@ -41,6 +43,7 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
{ {
_globalSettings = globalSettings; _globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor; _httpContextAccessor = httpContextAccessor;
_timeProvider = timeProvider;
} }
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds) public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
@ -148,7 +151,7 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
var message = new UserPushNotification var message = new UserPushNotification
{ {
UserId = userId, UserId = userId,
Date = DateTime.UtcNow Date = _timeProvider.GetUtcNow().UtcDateTime,
}; };
await SendMessageAsync(type, message, excludeCurrentContext); await SendMessageAsync(type, message, excludeCurrentContext);

View File

@ -27,13 +27,15 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
private readonly IDeviceRepository _deviceRepository; private readonly IDeviceRepository _deviceRepository;
private readonly IGlobalSettings _globalSettings; private readonly IGlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor; private readonly IHttpContextAccessor _httpContextAccessor;
private readonly TimeProvider _timeProvider;
public RelayPushNotificationService( public RelayPushNotificationService(
IHttpClientFactory httpFactory, IHttpClientFactory httpFactory,
IDeviceRepository deviceRepository, IDeviceRepository deviceRepository,
GlobalSettings globalSettings, GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor, IHttpContextAccessor httpContextAccessor,
ILogger<RelayPushNotificationService> logger) ILogger<RelayPushNotificationService> logger,
TimeProvider timeProvider)
: base( : base(
httpFactory, httpFactory,
globalSettings.PushRelayBaseUri, globalSettings.PushRelayBaseUri,
@ -46,6 +48,7 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
_deviceRepository = deviceRepository; _deviceRepository = deviceRepository;
_globalSettings = globalSettings; _globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor; _httpContextAccessor = httpContextAccessor;
_timeProvider = timeProvider;
} }
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds) public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
@ -147,7 +150,7 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false) private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false)
{ {
var message = new UserPushNotification { UserId = userId, Date = DateTime.UtcNow }; var message = new UserPushNotification { UserId = userId, Date = _timeProvider.GetUtcNow().UtcDateTime };
await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext); await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext);
} }

View File

@ -70,6 +70,7 @@ public interface IMailService
Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage); Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage);
Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName); Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName);
Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email); Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email);
Task SendBusinessUnitConversionInviteAsync(Organization organization, string token, string email);
Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email); Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email);
Task SendProviderConfirmedEmailAsync(string providerName, string email); Task SendProviderConfirmedEmailAsync(string providerName, string email);
Task SendProviderUserRemoved(string providerName, string email); Task SendProviderUserRemoved(string providerName, string email);

View File

@ -11,6 +11,7 @@ using Bit.Core.Billing.Models.Mail;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Models.Data.Organizations; using Bit.Core.Models.Data.Organizations;
using Bit.Core.Models.Mail; using Bit.Core.Models.Mail;
using Bit.Core.Models.Mail.Billing;
using Bit.Core.Models.Mail.FamiliesForEnterprise; using Bit.Core.Models.Mail.FamiliesForEnterprise;
using Bit.Core.Models.Mail.Provider; using Bit.Core.Models.Mail.Provider;
using Bit.Core.SecretsManager.Models.Mail; using Bit.Core.SecretsManager.Models.Mail;
@ -949,6 +950,22 @@ public class HandlebarsMailService : IMailService
await _mailDeliveryService.SendEmailAsync(message); await _mailDeliveryService.SendEmailAsync(message);
} }
public async Task SendBusinessUnitConversionInviteAsync(Organization organization, string token, string email)
{
var message = CreateDefaultMessage("Set Up Business Unit", email);
var model = new BusinessUnitConversionInviteModel
{
WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash,
SiteName = _globalSettings.SiteName,
OrganizationId = organization.Id.ToString(),
Email = WebUtility.UrlEncode(email),
Token = WebUtility.UrlEncode(token)
};
await AddMessageContentAsync(message, "Billing.BusinessUnitConversionInvite", model);
message.Category = "BusinessUnitConversionInvite";
await _mailDeliveryService.SendEmailAsync(message);
}
public async Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) public async Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email)
{ {
var message = CreateDefaultMessage($"Join {providerName}", email); var message = CreateDefaultMessage($"Join {providerName}", email);

View File

@ -212,6 +212,11 @@ public class NoopMailService : IMailService
return Task.FromResult(0); return Task.FromResult(0);
} }
public Task SendBusinessUnitConversionInviteAsync(Organization organization, string token, string email)
{
return Task.FromResult(0);
}
public Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) public Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email)
{ {
return Task.FromResult(0); return Task.FromResult(0);

View File

@ -711,7 +711,7 @@ public class CipherRepository : Repository<Cipher, Guid>, ICipherRepository
row[creationDateColumn] = cipher.CreationDate; row[creationDateColumn] = cipher.CreationDate;
row[revisionDateColumn] = cipher.RevisionDate; row[revisionDateColumn] = cipher.RevisionDate;
row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value; row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value;
row[repromptColumn] = cipher.Reprompt; row[repromptColumn] = cipher.Reprompt.HasValue ? cipher.Reprompt.Value : DBNull.Value;
row[keyColummn] = cipher.Key; row[keyColummn] = cipher.Key;
ciphersTable.Rows.Add(row); ciphersTable.Rows.Add(row);

View File

@ -332,7 +332,7 @@ public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Orga
var planTypes = providerType switch var planTypes = providerType switch
{ {
ProviderType.Msp => PlanConstants.EnterprisePlanTypes.Concat(PlanConstants.TeamsPlanTypes), ProviderType.Msp => PlanConstants.EnterprisePlanTypes.Concat(PlanConstants.TeamsPlanTypes),
ProviderType.MultiOrganizationEnterprise => PlanConstants.EnterprisePlanTypes, ProviderType.BusinessUnit => PlanConstants.EnterprisePlanTypes,
_ => [] _ => []
}; };

View File

@ -35,6 +35,7 @@ public class ProviderUserProviderDetailsReadByUserIdStatusQuery : IQuery<Provide
Permissions = x.pu.Permissions, Permissions = x.pu.Permissions,
UseEvents = x.p.UseEvents, UseEvents = x.p.UseEvents,
ProviderStatus = x.p.Status, ProviderStatus = x.p.Status,
ProviderType = x.p.Type
}); });
} }
} }

View File

@ -863,8 +863,30 @@ public class CipherRepository : Repository<Core.Vault.Entities.Cipher, Cipher, G
using (var scope = ServiceScopeFactory.CreateScope()) using (var scope = ServiceScopeFactory.CreateScope())
{ {
var dbContext = GetDatabaseContext(scope); var dbContext = GetDatabaseContext(scope);
var entities = Mapper.Map<List<Cipher>>(ciphers); var ciphersToUpdate = ciphers.ToDictionary(c => c.Id);
await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, entities);
var existingCiphers = await dbContext.Ciphers
.Where(c => c.UserId == userId && ciphersToUpdate.Keys.Contains(c.Id))
.ToDictionaryAsync(c => c.Id);
foreach (var (cipherId, cipher) in ciphersToUpdate)
{
if (!existingCiphers.TryGetValue(cipherId, out var existingCipher))
{
// The Dapper version does not validate that the same amount of items given where updated.
continue;
}
existingCipher.UserId = cipher.UserId;
existingCipher.OrganizationId = cipher.OrganizationId;
existingCipher.Type = cipher.Type;
existingCipher.Data = cipher.Data;
existingCipher.Attachments = cipher.Attachments;
existingCipher.RevisionDate = cipher.RevisionDate;
existingCipher.DeletedDate = cipher.DeletedDate;
existingCipher.Key = cipher.Key;
}
await dbContext.UserBumpAccountRevisionDateAsync(userId); await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync(); await dbContext.SaveChangesAsync();
} }

View File

@ -67,6 +67,7 @@ using Microsoft.Extensions.Caching.Cosmos;
using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
@ -279,6 +280,8 @@ public static class ServiceCollectionExtensions
services.AddSingleton<IMailDeliveryService, NoopMailDeliveryService>(); services.AddSingleton<IMailDeliveryService, NoopMailDeliveryService>();
} }
services.TryAddSingleton(TimeProvider.System);
services.AddSingleton<IPushNotificationService, MultiServicePushNotificationService>(); services.AddSingleton<IPushNotificationService, MultiServicePushNotificationService>();
if (globalSettings.SelfHosted) if (globalSettings.SelfHosted)
{ {

View File

@ -75,25 +75,25 @@ public class ProvidersControllerTests
} }
#endregion #endregion
#region CreateMultiOrganizationEnterpriseAsync #region CreateBusinessUnitAsync
[BitAutoData] [BitAutoData]
[SutProviderCustomize] [SutProviderCustomize]
[Theory] [Theory]
public async Task CreateMultiOrganizationEnterpriseAsync_WithValidModel_CreatesProvider( public async Task CreateBusinessUnitAsync_WithValidModel_CreatesProvider(
CreateMultiOrganizationEnterpriseProviderModel model, CreateBusinessUnitProviderModel model,
SutProvider<ProvidersController> sutProvider) SutProvider<ProvidersController> sutProvider)
{ {
// Arrange // Arrange
// Act // Act
var actual = await sutProvider.Sut.CreateMultiOrganizationEnterprise(model); var actual = await sutProvider.Sut.CreateBusinessUnit(model);
// Assert // Assert
Assert.NotNull(actual); Assert.NotNull(actual);
await sutProvider.GetDependency<ICreateProviderCommand>() await sutProvider.GetDependency<ICreateProviderCommand>()
.Received(Quantity.Exactly(1)) .Received(Quantity.Exactly(1))
.CreateMultiOrganizationEnterpriseAsync( .CreateBusinessUnitAsync(
Arg.Is<Provider>(x => x.Type == ProviderType.MultiOrganizationEnterprise), Arg.Is<Provider>(x => x.Type == ProviderType.BusinessUnit),
model.OwnerEmail, model.OwnerEmail,
Arg.Is<PlanType>(y => y == model.Plan), Arg.Is<PlanType>(y => y == model.Plan),
model.EnterpriseSeatMinimum); model.EnterpriseSeatMinimum);
@ -102,16 +102,16 @@ public class ProvidersControllerTests
[BitAutoData] [BitAutoData]
[SutProviderCustomize] [SutProviderCustomize]
[Theory] [Theory]
public async Task CreateMultiOrganizationEnterpriseAsync_RedirectsToExpectedPage_AfterCreatingProvider( public async Task CreateBusinessUnitAsync_RedirectsToExpectedPage_AfterCreatingProvider(
CreateMultiOrganizationEnterpriseProviderModel model, CreateBusinessUnitProviderModel model,
Guid expectedProviderId, Guid expectedProviderId,
SutProvider<ProvidersController> sutProvider) SutProvider<ProvidersController> sutProvider)
{ {
// Arrange // Arrange
sutProvider.GetDependency<ICreateProviderCommand>() sutProvider.GetDependency<ICreateProviderCommand>()
.When(x => .When(x =>
x.CreateMultiOrganizationEnterpriseAsync( x.CreateBusinessUnitAsync(
Arg.Is<Provider>(y => y.Type == ProviderType.MultiOrganizationEnterprise), Arg.Is<Provider>(y => y.Type == ProviderType.BusinessUnit),
model.OwnerEmail, model.OwnerEmail,
Arg.Is<PlanType>(y => y == model.Plan), Arg.Is<PlanType>(y => y == model.Plan),
model.EnterpriseSeatMinimum)) model.EnterpriseSeatMinimum))
@ -122,7 +122,7 @@ public class ProvidersControllerTests
}); });
// Act // Act
var actual = await sutProvider.Sut.CreateMultiOrganizationEnterprise(model); var actual = await sutProvider.Sut.CreateBusinessUnit(model);
// Assert // Assert
Assert.NotNull(actual); Assert.NotNull(actual);

View File

@ -60,6 +60,7 @@ public class OrganizationsControllerTests : IDisposable
private readonly IOrganizationDeleteCommand _organizationDeleteCommand; private readonly IOrganizationDeleteCommand _organizationDeleteCommand;
private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IPolicyRequirementQuery _policyRequirementQuery;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
private readonly IOrganizationUpdateKeysCommand _organizationUpdateKeysCommand;
private readonly OrganizationsController _sut; private readonly OrganizationsController _sut;
public OrganizationsControllerTests() public OrganizationsControllerTests()
@ -86,6 +87,7 @@ public class OrganizationsControllerTests : IDisposable
_organizationDeleteCommand = Substitute.For<IOrganizationDeleteCommand>(); _organizationDeleteCommand = Substitute.For<IOrganizationDeleteCommand>();
_policyRequirementQuery = Substitute.For<IPolicyRequirementQuery>(); _policyRequirementQuery = Substitute.For<IPolicyRequirementQuery>();
_pricingClient = Substitute.For<IPricingClient>(); _pricingClient = Substitute.For<IPricingClient>();
_organizationUpdateKeysCommand = Substitute.For<IOrganizationUpdateKeysCommand>();
_sut = new OrganizationsController( _sut = new OrganizationsController(
_organizationRepository, _organizationRepository,
@ -109,7 +111,8 @@ public class OrganizationsControllerTests : IDisposable
_cloudOrganizationSignUpCommand, _cloudOrganizationSignUpCommand,
_organizationDeleteCommand, _organizationDeleteCommand,
_policyRequirementQuery, _policyRequirementQuery,
_pricingClient); _pricingClient,
_organizationUpdateKeysCommand);
} }
public void Dispose() public void Dispose()

View File

@ -2,6 +2,7 @@
using Bit.Api.Models.Response; using Bit.Api.Models.Response;
using Bit.Core.Auth.Models.Api.Response; using Bit.Core.Auth.Models.Api.Response;
using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Models.Data;
using Bit.Core.Auth.UserFeatures.DeviceTrust;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums; using Bit.Core.Enums;
@ -19,6 +20,7 @@ public class DevicesControllerTest
private readonly IDeviceRepository _deviceRepositoryMock; private readonly IDeviceRepository _deviceRepositoryMock;
private readonly IDeviceService _deviceServiceMock; private readonly IDeviceService _deviceServiceMock;
private readonly IUserService _userServiceMock; private readonly IUserService _userServiceMock;
private readonly IUntrustDevicesCommand _untrustDevicesCommand;
private readonly IUserRepository _userRepositoryMock; private readonly IUserRepository _userRepositoryMock;
private readonly ICurrentContext _currentContextMock; private readonly ICurrentContext _currentContextMock;
private readonly IGlobalSettings _globalSettingsMock; private readonly IGlobalSettings _globalSettingsMock;
@ -30,6 +32,7 @@ public class DevicesControllerTest
_deviceRepositoryMock = Substitute.For<IDeviceRepository>(); _deviceRepositoryMock = Substitute.For<IDeviceRepository>();
_deviceServiceMock = Substitute.For<IDeviceService>(); _deviceServiceMock = Substitute.For<IDeviceService>();
_userServiceMock = Substitute.For<IUserService>(); _userServiceMock = Substitute.For<IUserService>();
_untrustDevicesCommand = Substitute.For<IUntrustDevicesCommand>();
_userRepositoryMock = Substitute.For<IUserRepository>(); _userRepositoryMock = Substitute.For<IUserRepository>();
_currentContextMock = Substitute.For<ICurrentContext>(); _currentContextMock = Substitute.For<ICurrentContext>();
_loggerMock = Substitute.For<ILogger<DevicesController>>(); _loggerMock = Substitute.For<ILogger<DevicesController>>();
@ -38,6 +41,7 @@ public class DevicesControllerTest
_deviceRepositoryMock, _deviceRepositoryMock,
_deviceServiceMock, _deviceServiceMock,
_userServiceMock, _userServiceMock,
_untrustDevicesCommand,
_userRepositoryMock, _userRepositoryMock,
_currentContextMock, _currentContextMock,
_loggerMock); _loggerMock);

View File

@ -1,7 +1,6 @@
using Bit.Billing.Constants; using Bit.Billing.Constants;
using Bit.Billing.Services; using Bit.Billing.Services;
using Bit.Billing.Services.Implementations; using Bit.Billing.Services.Implementations;
using Bit.Core;
using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces;
using Bit.Core.Billing.Enums; using Bit.Core.Billing.Enums;
@ -31,7 +30,6 @@ public class SubscriptionUpdatedHandlerTests
private readonly IPushNotificationService _pushNotificationService; private readonly IPushNotificationService _pushNotificationService;
private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationRepository _organizationRepository;
private readonly ISchedulerFactory _schedulerFactory; private readonly ISchedulerFactory _schedulerFactory;
private readonly IFeatureService _featureService;
private readonly IOrganizationEnableCommand _organizationEnableCommand; private readonly IOrganizationEnableCommand _organizationEnableCommand;
private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IPricingClient _pricingClient; private readonly IPricingClient _pricingClient;
@ -49,7 +47,6 @@ public class SubscriptionUpdatedHandlerTests
_pushNotificationService = Substitute.For<IPushNotificationService>(); _pushNotificationService = Substitute.For<IPushNotificationService>();
_organizationRepository = Substitute.For<IOrganizationRepository>(); _organizationRepository = Substitute.For<IOrganizationRepository>();
_schedulerFactory = Substitute.For<ISchedulerFactory>(); _schedulerFactory = Substitute.For<ISchedulerFactory>();
_featureService = Substitute.For<IFeatureService>();
_organizationEnableCommand = Substitute.For<IOrganizationEnableCommand>(); _organizationEnableCommand = Substitute.For<IOrganizationEnableCommand>();
_organizationDisableCommand = Substitute.For<IOrganizationDisableCommand>(); _organizationDisableCommand = Substitute.For<IOrganizationDisableCommand>();
_pricingClient = Substitute.For<IPricingClient>(); _pricingClient = Substitute.For<IPricingClient>();
@ -67,7 +64,6 @@ public class SubscriptionUpdatedHandlerTests
_pushNotificationService, _pushNotificationService,
_organizationRepository, _organizationRepository,
_schedulerFactory, _schedulerFactory,
_featureService,
_organizationEnableCommand, _organizationEnableCommand,
_organizationDisableCommand, _organizationDisableCommand,
_pricingClient); _pricingClient);
@ -97,9 +93,6 @@ public class SubscriptionUpdatedHandlerTests
_stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>()) _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any<Dictionary<string, string>>())
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(organizationId, null, null)); .Returns(Tuple.Create<Guid?, Guid?, Guid?>(organizationId, null, null));
_featureService.IsEnabled(FeatureFlagKeys.ResellerManagedOrgAlert)
.Returns(true);
// Act // Act
await _sut.HandleAsync(parsedEvent); await _sut.HandleAsync(parsedEvent);

View File

@ -0,0 +1,75 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Context;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Organizations;
[SutProviderCustomize]
public class OrganizationUpdateKeysCommandTests
{
[Theory, BitAutoData]
public async Task UpdateOrganizationKeysAsync_WithoutManageResetPasswordPermission_ThrowsUnauthorizedException(
Guid orgId, string publicKey, string privateKey, SutProvider<OrganizationUpdateKeysCommand> sutProvider)
{
sutProvider.GetDependency<ICurrentContext>()
.ManageResetPassword(orgId)
.Returns(false);
await Assert.ThrowsAsync<UnauthorizedAccessException>(
() => sutProvider.Sut.UpdateOrganizationKeysAsync(orgId, publicKey, privateKey));
}
[Theory, BitAutoData]
public async Task UpdateOrganizationKeysAsync_WhenKeysAlreadyExist_ThrowsBadRequestException(
Organization organization, string publicKey, string privateKey,
SutProvider<OrganizationUpdateKeysCommand> sutProvider)
{
organization.PublicKey = "existingPublicKey";
organization.PrivateKey = "existingPrivateKey";
sutProvider.GetDependency<ICurrentContext>()
.ManageResetPassword(organization.Id)
.Returns(true);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.UpdateOrganizationKeysAsync(organization.Id, publicKey, privateKey));
Assert.Equal(OrganizationUpdateKeysCommand.OrganizationKeysAlreadyExistErrorMessage, exception.Message);
}
[Theory, BitAutoData]
public async Task UpdateOrganizationKeysAsync_WhenKeysDoNotExist_UpdatesOrganization(
Organization organization, string publicKey, string privateKey,
SutProvider<OrganizationUpdateKeysCommand> sutProvider)
{
organization.PublicKey = null;
organization.PrivateKey = null;
sutProvider.GetDependency<ICurrentContext>()
.ManageResetPassword(organization.Id)
.Returns(true);
sutProvider.GetDependency<IOrganizationRepository>()
.GetByIdAsync(organization.Id)
.Returns(organization);
var result = await sutProvider.Sut.UpdateOrganizationKeysAsync(organization.Id, publicKey, privateKey);
Assert.Equal(publicKey, result.PublicKey);
Assert.Equal(privateKey, result.PrivateKey);
await sutProvider.GetDependency<IOrganizationService>()
.Received(1)
.UpdateAsync(organization);
}
}

View File

@ -814,48 +814,6 @@ public class OrganizationServiceTests
sutProvider.GetDependency<ICurrentContext>().ManageUsers(organization.Id).Returns(true); sutProvider.GetDependency<ICurrentContext>().ManageUsers(organization.Id).Returns(true);
} }
[Theory, BitAutoData]
public async Task UpdateOrganizationKeysAsync_WithoutManageResetPassword_Throws(Guid orgId, string publicKey,
string privateKey, SutProvider<OrganizationService> sutProvider)
{
var currentContext = Substitute.For<ICurrentContext>();
currentContext.ManageResetPassword(orgId).Returns(false);
await Assert.ThrowsAsync<UnauthorizedAccessException>(
() => sutProvider.Sut.UpdateOrganizationKeysAsync(orgId, publicKey, privateKey));
}
[Theory, BitAutoData]
public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Throws(Organization org, string publicKey,
string privateKey, SutProvider<OrganizationService> sutProvider)
{
var currentContext = sutProvider.GetDependency<ICurrentContext>();
currentContext.ManageResetPassword(org.Id).Returns(true);
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
organizationRepository.GetByIdAsync(org.Id).Returns(org);
var exception = await Assert.ThrowsAsync<BadRequestException>(
() => sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey));
Assert.Contains("Organization Keys already exist", exception.Message);
}
[Theory, BitAutoData]
public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Success(Organization org, string publicKey,
string privateKey, SutProvider<OrganizationService> sutProvider)
{
org.PublicKey = null;
org.PrivateKey = null;
var currentContext = sutProvider.GetDependency<ICurrentContext>();
currentContext.ManageResetPassword(org.Id).Returns(true);
var organizationRepository = sutProvider.GetDependency<IOrganizationRepository>();
organizationRepository.GetByIdAsync(org.Id).Returns(org);
await sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey);
}
[Theory] [Theory]
[PaidOrganizationCustomize(CheckedPlanType = PlanType.EnterpriseAnnually)] [PaidOrganizationCustomize(CheckedPlanType = PlanType.EnterpriseAnnually)]
[BitAutoData("Cannot set max seat autoscaling below seat count", 1, 0, 2, 2)] [BitAutoData("Cannot set max seat autoscaling below seat count", 1, 0, 2, 2)]

View File

@ -0,0 +1,55 @@
using Bit.Core.Auth.UserFeatures.DeviceTrust;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes;
using NSubstitute;
using Xunit;
namespace Bit.Core.Test.Auth.UserFeatures.WebAuthnLogin;
[SutProviderCustomize]
public class UntrustDevicesCommandTests
{
[Theory, BitAutoData]
public async Task SetsKeysToNull(SutProvider<UntrustDevicesCommand> sutProvider, User user)
{
var deviceId = Guid.NewGuid();
// Arrange
sutProvider.GetDependency<IDeviceRepository>()
.GetManyByUserIdAsync(user.Id)
.Returns([new Device
{
Id = deviceId,
EncryptedPrivateKey = "encryptedPrivateKey",
EncryptedPublicKey = "encryptedPublicKey",
EncryptedUserKey = "encryptedUserKey"
}]);
// Act
await sutProvider.Sut.UntrustDevices(user, new List<Guid> { deviceId });
// Assert
await sutProvider.GetDependency<IDeviceRepository>()
.Received()
.UpsertAsync(Arg.Is<Device>(d =>
d.Id == deviceId &&
d.EncryptedPrivateKey == null &&
d.EncryptedPublicKey == null &&
d.EncryptedUserKey == null));
}
[Theory, BitAutoData]
public async Task RejectsWrongUser(SutProvider<UntrustDevicesCommand> sutProvider, User user)
{
var deviceId = Guid.NewGuid();
// Arrange
sutProvider.GetDependency<IDeviceRepository>()
.GetManyByUserIdAsync(user.Id)
.Returns([]);
// Act
await Assert.ThrowsAsync<UnauthorizedAccessException>(async () =>
await sutProvider.Sut.UntrustDevices(user, new List<Guid> { deviceId }));
}
}

View File

@ -1,15 +1,25 @@
#nullable enable #nullable enable
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Nodes;
using Bit.Core.Auth.Entities;
using Bit.Core.Context;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models; using Bit.Core.Models;
using Bit.Core.Models.Data; using Bit.Core.Models.Data;
using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Enums;
using Bit.Core.NotificationHub; using Bit.Core.NotificationHub;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Test.NotificationCenter.AutoFixture; using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Time.Testing;
using NSubstitute; using NSubstitute;
using Xunit; using Xunit;
@ -19,6 +29,10 @@ namespace Bit.Core.Test.NotificationHub;
[NotificationStatusCustomize] [NotificationStatusCustomize]
public class NotificationHubPushNotificationServiceTests public class NotificationHubPushNotificationServiceTests
{ {
private static readonly string _deviceIdentifier = "test_device_identifier";
private static readonly DateTime _now = DateTime.UtcNow;
private static readonly Guid _installationId = Guid.Parse("da73177b-513f-4444-b582-595c890e1022");
[Theory] [Theory]
[BitAutoData] [BitAutoData]
[NotificationCustomize] [NotificationCustomize]
@ -496,6 +510,630 @@ public class NotificationHubPushNotificationServiceTests
.UpsertAsync(Arg.Any<InstallationDeviceEntity>()); .UpsertAsync(Arg.Any<InstallationDeviceEntity>());
} }
[Fact]
public async Task PushSyncCipherCreateAsync_SendExpectedData()
{
var collectionId = Guid.NewGuid();
var userId = Guid.NewGuid();
var cipher = new Cipher
{
Id = Guid.NewGuid(),
UserId = userId,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = cipher.OrganizationId,
["CollectionIds"] = new JsonArray(collectionId),
["RevisionDate"] = cipher.RevisionDate,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncCipherCreateAsync(cipher, [collectionId]),
PushType.SyncCipherCreate,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushSyncCipherUpdateAsync_SendExpectedData()
{
var collectionId = Guid.NewGuid();
var userId = Guid.NewGuid();
var cipher = new Cipher
{
Id = Guid.NewGuid(),
UserId = userId,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = cipher.OrganizationId,
["CollectionIds"] = new JsonArray(collectionId),
["RevisionDate"] = cipher.RevisionDate,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncCipherUpdateAsync(cipher, [collectionId]),
PushType.SyncCipherUpdate,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushSyncCipherDeleteAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var cipher = new Cipher
{
Id = Guid.NewGuid(),
UserId = userId,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = cipher.OrganizationId,
["CollectionIds"] = null,
["RevisionDate"] = cipher.RevisionDate,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncCipherDeleteAsync(cipher),
PushType.SyncLoginDelete,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushSyncFolderCreateAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var folder = new Folder
{
Id = Guid.NewGuid(),
UserId = userId,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncFolderCreateAsync(folder),
PushType.SyncFolderCreate,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushSyncFolderUpdateAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var folder = new Folder
{
Id = Guid.NewGuid(),
UserId = userId,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncFolderUpdateAsync(folder),
PushType.SyncFolderUpdate,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushSyncSendCreateAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var send = new Send
{
Id = Guid.NewGuid(),
UserId = userId,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSendCreateAsync(send),
PushType.SyncSendCreate,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushAuthRequestAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var authRequest = new AuthRequest
{
Id = Guid.NewGuid(),
UserId = userId,
};
var expectedPayload = new JsonObject
{
["Id"] = authRequest.Id,
["UserId"] = authRequest.UserId,
};
await VerifyNotificationAsync(
async sut => await sut.PushAuthRequestAsync(authRequest),
PushType.AuthRequest,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushAuthRequestResponseAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var authRequest = new AuthRequest
{
Id = Guid.NewGuid(),
UserId = userId,
};
var expectedPayload = new JsonObject
{
["Id"] = authRequest.Id,
["UserId"] = authRequest.UserId,
};
await VerifyNotificationAsync(
async sut => await sut.PushAuthRequestResponseAsync(authRequest),
PushType.AuthRequestResponse,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushSyncSendUpdateAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var send = new Send
{
Id = Guid.NewGuid(),
UserId = userId,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSendUpdateAsync(send),
PushType.SyncSendUpdate,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushSyncSendDeleteAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var send = new Send
{
Id = Guid.NewGuid(),
UserId = userId,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSendDeleteAsync(send),
PushType.SyncSendDelete,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Fact]
public async Task PushSyncCiphersAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["UserId"] = userId,
["Date"] = _now,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncCiphersAsync(userId),
PushType.SyncCiphers,
expectedPayload,
$"(template:payload_userId:{userId})"
);
}
[Fact]
public async Task PushSyncVaultAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["UserId"] = userId,
["Date"] = _now,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncVaultAsync(userId),
PushType.SyncVault,
expectedPayload,
$"(template:payload_userId:{userId})"
);
}
[Fact]
public async Task PushSyncOrganizationsAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["UserId"] = userId,
["Date"] = _now,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrganizationsAsync(userId),
PushType.SyncOrganizations,
expectedPayload,
$"(template:payload_userId:{userId})"
);
}
[Fact]
public async Task PushSyncOrgKeysAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["UserId"] = userId,
["Date"] = _now,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrgKeysAsync(userId),
PushType.SyncOrgKeys,
expectedPayload,
$"(template:payload_userId:{userId})"
);
}
[Fact]
public async Task PushSyncSettingsAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["UserId"] = userId,
["Date"] = _now,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSettingsAsync(userId),
PushType.SyncSettings,
expectedPayload,
$"(template:payload_userId:{userId})"
);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task PushLogOutAsync_SendExpectedData(bool excludeCurrentContext)
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["UserId"] = userId,
["Date"] = _now,
};
var expectedTag = excludeCurrentContext
? $"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
: $"(template:payload_userId:{userId})";
await VerifyNotificationAsync(
async sut => await sut.PushLogOutAsync(userId, excludeCurrentContext),
PushType.LogOut,
expectedPayload,
expectedTag
);
}
[Fact]
public async Task PushSyncFolderDeleteAsync_SendExpectedData()
{
var userId = Guid.NewGuid();
var folder = new Folder
{
Id = Guid.NewGuid(),
UserId = userId,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncFolderDeleteAsync(folder),
PushType.SyncFolderDelete,
expectedPayload,
$"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})"
);
}
[Theory]
[InlineData(true, null, null)]
[InlineData(false, "e8e08ce8-8a26-4a65-913a-ba1d8c478b2f", null)]
[InlineData(false, null, "2f53ee32-edf9-4169-b276-760fe92e03bf")]
public async Task PushNotificationAsync_SendExpectedData(bool global, string? userId, string? organizationId)
{
var notification = new Notification
{
Id = Guid.NewGuid(),
Priority = Priority.High,
Global = global,
ClientType = ClientType.All,
UserId = userId != null ? Guid.Parse(userId) : null,
OrganizationId = organizationId != null ? Guid.Parse(organizationId) : null,
TaskId = Guid.NewGuid(),
Title = "My Title",
Body = "My Body",
CreationDate = DateTime.UtcNow.AddDays(-1),
RevisionDate = DateTime.UtcNow,
};
JsonNode? installationId = global ? _installationId : null;
var expectedPayload = new JsonObject
{
["Id"] = notification.Id,
["Priority"] = 3,
["Global"] = global,
["ClientType"] = 0,
["UserId"] = notification.UserId,
["OrganizationId"] = notification.OrganizationId,
["TaskId"] = notification.TaskId,
["InstallationId"] = installationId,
["Title"] = notification.Title,
["Body"] = notification.Body,
["CreationDate"] = notification.CreationDate,
["RevisionDate"] = notification.RevisionDate,
["ReadDate"] = null,
["DeletedDate"] = null,
};
string expectedTag;
if (global)
{
expectedTag = $"(template:payload && installationId:{_installationId} && !deviceIdentifier:{_deviceIdentifier})";
}
else if (notification.OrganizationId.HasValue)
{
expectedTag = "(template:payload && organizationId:2f53ee32-edf9-4169-b276-760fe92e03bf && !deviceIdentifier:test_device_identifier)";
}
else
{
expectedTag = $"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})";
}
await VerifyNotificationAsync(
async sut => await sut.PushNotificationAsync(notification),
PushType.Notification,
expectedPayload,
expectedTag
);
}
[Theory]
[InlineData(true, null, null)]
[InlineData(false, "e8e08ce8-8a26-4a65-913a-ba1d8c478b2f", null)]
[InlineData(false, null, "2f53ee32-edf9-4169-b276-760fe92e03bf")]
public async Task PushNotificationStatusAsync_SendExpectedData(bool global, string? userId, string? organizationId)
{
var notification = new Notification
{
Id = Guid.NewGuid(),
Priority = Priority.High,
Global = global,
ClientType = ClientType.All,
UserId = userId != null ? Guid.Parse(userId) : null,
OrganizationId = organizationId != null ? Guid.Parse(organizationId) : null,
Title = "My Title",
Body = "My Body",
CreationDate = DateTime.UtcNow.AddDays(-1),
RevisionDate = DateTime.UtcNow,
};
var notificationStatus = new NotificationStatus
{
ReadDate = DateTime.UtcNow.AddDays(-1),
DeletedDate = DateTime.UtcNow,
};
JsonNode? installationId = global ? _installationId : null;
var expectedPayload = new JsonObject
{
["Id"] = notification.Id,
["Priority"] = 3,
["Global"] = global,
["ClientType"] = 0,
["UserId"] = notification.UserId,
["OrganizationId"] = notification.OrganizationId,
["TaskId"] = notification.TaskId,
["InstallationId"] = installationId,
["Title"] = notification.Title,
["Body"] = notification.Body,
["CreationDate"] = notification.CreationDate,
["RevisionDate"] = notification.RevisionDate,
["ReadDate"] = notificationStatus.ReadDate,
["DeletedDate"] = notificationStatus.DeletedDate,
};
string expectedTag;
if (global)
{
expectedTag = $"(template:payload && installationId:{_installationId} && !deviceIdentifier:{_deviceIdentifier})";
}
else if (notification.OrganizationId.HasValue)
{
expectedTag = "(template:payload && organizationId:2f53ee32-edf9-4169-b276-760fe92e03bf && !deviceIdentifier:test_device_identifier)";
}
else
{
expectedTag = $"(template:payload_userId:{userId} && !deviceIdentifier:{_deviceIdentifier})";
}
await VerifyNotificationAsync(
async sut => await sut.PushNotificationStatusAsync(notification, notificationStatus),
PushType.NotificationStatus,
expectedPayload,
expectedTag
);
}
private async Task VerifyNotificationAsync(Func<NotificationHubPushNotificationService, Task> test,
PushType type, JsonNode expectedPayload, string tag)
{
var installationDeviceRepository = Substitute.For<IInstallationDeviceRepository>();
var notificationHubPool = Substitute.For<INotificationHubPool>();
var notificationHubProxy = Substitute.For<INotificationHubProxy>();
notificationHubPool.AllClients
.Returns(notificationHubProxy);
var httpContextAccessor = Substitute.For<IHttpContextAccessor>();
var httpContext = new DefaultHttpContext();
var serviceCollection = new ServiceCollection();
var currentContext = Substitute.For<ICurrentContext>();
currentContext.DeviceIdentifier = _deviceIdentifier;
serviceCollection.AddSingleton(currentContext);
httpContext.RequestServices = serviceCollection.BuildServiceProvider();
httpContextAccessor.HttpContext
.Returns(httpContext);
var globalSettings = new Core.Settings.GlobalSettings();
globalSettings.Installation.Id = _installationId;
var fakeTimeProvider = new FakeTimeProvider();
fakeTimeProvider.SetUtcNow(_now);
var sut = new NotificationHubPushNotificationService(
installationDeviceRepository,
notificationHubPool,
httpContextAccessor,
NullLogger<NotificationHubPushNotificationService>.Instance,
globalSettings,
fakeTimeProvider
);
// Act
await test(sut);
// Assert
var calls = notificationHubProxy.ReceivedCalls();
var methodInfo = typeof(INotificationHubProxy).GetMethod(nameof(INotificationHubProxy.SendTemplateNotificationAsync));
var call = Assert.Single(calls, c => c.GetMethodInfo() == methodInfo);
var arguments = call.GetArguments();
var dictionaryArg = (Dictionary<string, string>)arguments[0]!;
var tagArg = (string)arguments[1]!;
Assert.Equal(2, dictionaryArg.Count);
Assert.True(dictionaryArg.TryGetValue("type", out var typeString));
Assert.True(byte.TryParse(typeString, out var typeByte));
Assert.Equal(type, (PushType)typeByte);
Assert.True(dictionaryArg.TryGetValue("payload", out var payloadString));
var actualPayloadNode = JsonNode.Parse(payloadString);
Assert.True(JsonNode.DeepEquals(expectedPayload, actualPayloadNode));
Assert.Equal(tag, tagArg);
}
private static NotificationPushNotification ToNotificationPushNotification(Notification notification, private static NotificationPushNotification ToNotificationPushNotification(Notification notification,
NotificationStatus? notificationStatus, Guid? installationId) => NotificationStatus? notificationStatus, Guid? installationId) =>
new() new()

View File

@ -1,18 +1,27 @@
#nullable enable #nullable enable
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Nodes;
using Azure.Storage.Queues; using Azure.Storage.Queues;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Enums; using Bit.Core.Enums;
using Bit.Core.Models; using Bit.Core.Models;
using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Enums;
using Bit.Core.Platform.Push.Internal; using Bit.Core.Platform.Push.Internal;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Test.AutoFixture; using Bit.Core.Test.AutoFixture;
using Bit.Core.Test.AutoFixture.CurrentContextFixtures; using Bit.Core.Test.AutoFixture.CurrentContextFixtures;
using Bit.Core.Test.NotificationCenter.AutoFixture; using Bit.Core.Test.NotificationCenter.AutoFixture;
using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities;
using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture;
using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.AutoFixture.Attributes;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Time.Testing;
using NSubstitute; using NSubstitute;
using Xunit; using Xunit;
@ -22,6 +31,17 @@ namespace Bit.Core.Test.Platform.Push.Services;
[SutProviderCustomize] [SutProviderCustomize]
public class AzureQueuePushNotificationServiceTests public class AzureQueuePushNotificationServiceTests
{ {
private static readonly Guid _deviceId = Guid.Parse("c4730f80-caaa-4772-97bd-5c0d23a2baa3");
private static readonly string _deviceIdentifier = "test_device_identifier";
private readonly FakeTimeProvider _fakeTimeProvider;
private readonly Core.Settings.GlobalSettings _globalSettings = new();
public AzureQueuePushNotificationServiceTests()
{
_fakeTimeProvider = new();
_fakeTimeProvider.SetUtcNow(DateTime.UtcNow);
}
[Theory] [Theory]
[BitAutoData] [BitAutoData]
[NotificationCustomize] [NotificationCustomize]
@ -112,6 +132,761 @@ public class AzureQueuePushNotificationServiceTests
deviceIdentifier.ToString()))); deviceIdentifier.ToString())));
} }
[Theory]
[InlineData("6a5bbe1b-cf16-49a6-965f-5c2eac56a531", null)]
[InlineData(null, "b9a3fcb4-2447-45c1-aad2-24de43c88c44")]
public async Task PushSyncCipherCreateAsync_SendsExpectedResponse(string? userId, string? organizationId)
{
var collectionId = Guid.NewGuid();
var cipher = new Cipher
{
Id = Guid.NewGuid(),
UserId = userId != null ? Guid.Parse(userId) : null,
OrganizationId = organizationId != null ? Guid.Parse(organizationId) : null,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 1,
["Payload"] = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = cipher.OrganizationId,
["CollectionIds"] = new JsonArray(collectionId),
["RevisionDate"] = cipher.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
if (!cipher.UserId.HasValue)
{
expectedPayload["Payload"]!.AsObject().Remove("UserId");
}
if (!cipher.OrganizationId.HasValue)
{
expectedPayload["Payload"]!.AsObject().Remove("OrganizationId");
expectedPayload["Payload"]!.AsObject().Remove("CollectionIds");
}
await VerifyNotificationAsync(
async sut => await sut.PushSyncCipherCreateAsync(cipher, [collectionId]),
expectedPayload
);
}
[Theory]
[InlineData("6a5bbe1b-cf16-49a6-965f-5c2eac56a531", null)]
[InlineData(null, "b9a3fcb4-2447-45c1-aad2-24de43c88c44")]
public async Task PushSyncCipherUpdateAsync_SendsExpectedResponse(string? userId, string? organizationId)
{
var collectionId = Guid.NewGuid();
var cipher = new Cipher
{
Id = Guid.NewGuid(),
UserId = userId != null ? Guid.Parse(userId) : null,
OrganizationId = organizationId != null ? Guid.Parse(organizationId) : null,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 0,
["Payload"] = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = cipher.OrganizationId,
["CollectionIds"] = new JsonArray(collectionId),
["RevisionDate"] = cipher.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
if (!cipher.UserId.HasValue)
{
expectedPayload["Payload"]!.AsObject().Remove("UserId");
}
if (!cipher.OrganizationId.HasValue)
{
expectedPayload["Payload"]!.AsObject().Remove("OrganizationId");
expectedPayload["Payload"]!.AsObject().Remove("CollectionIds");
}
await VerifyNotificationAsync(
async sut => await sut.PushSyncCipherUpdateAsync(cipher, [collectionId]),
expectedPayload
);
}
[Fact]
public async Task PushSyncCipherDeleteAsync_SendsExpectedResponse()
{
var collectionId = Guid.NewGuid();
var cipher = new Cipher
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
OrganizationId = null,
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 2,
["Payload"] = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["RevisionDate"] = cipher.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncCipherDeleteAsync(cipher),
expectedPayload
);
}
[Fact]
public async Task PushSyncFolderCreateAsync_SendsExpectedResponse()
{
var collectionId = Guid.NewGuid();
var folder = new Folder
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 7,
["Payload"] = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncFolderCreateAsync(folder),
expectedPayload
);
}
[Fact]
public async Task PushSyncFolderUpdateAsync_SendsExpectedResponse()
{
var collectionId = Guid.NewGuid();
var folder = new Folder
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 8,
["Payload"] = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncFolderUpdateAsync(folder),
expectedPayload
);
}
[Fact]
public async Task PushSyncFolderDeleteAsync_SendsExpectedResponse()
{
var collectionId = Guid.NewGuid();
var folder = new Folder
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 3,
["Payload"] = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncFolderDeleteAsync(folder),
expectedPayload
);
}
[Fact]
public async Task PushSyncCiphersAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["Type"] = 4,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = _fakeTimeProvider.GetUtcNow().UtcDateTime,
},
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncCiphersAsync(userId),
expectedPayload
);
}
[Fact]
public async Task PushSyncVaultAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["Type"] = 5,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = _fakeTimeProvider.GetUtcNow().UtcDateTime,
},
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncVaultAsync(userId),
expectedPayload
);
}
[Fact]
public async Task PushSyncOrganizationsAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["Type"] = 17,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = _fakeTimeProvider.GetUtcNow().UtcDateTime,
},
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrganizationsAsync(userId),
expectedPayload
);
}
[Fact]
public async Task PushSyncOrgKeysAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["Type"] = 6,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = _fakeTimeProvider.GetUtcNow().UtcDateTime,
},
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrgKeysAsync(userId),
expectedPayload
);
}
[Fact]
public async Task PushSyncSettingsAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["Type"] = 10,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = _fakeTimeProvider.GetUtcNow().UtcDateTime,
},
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSettingsAsync(userId),
expectedPayload
);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task PushLogOutAsync_SendsExpectedResponse(bool excludeCurrentContext)
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["Type"] = 11,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = _fakeTimeProvider.GetUtcNow().UtcDateTime,
},
};
if (excludeCurrentContext)
{
expectedPayload["ContextId"] = _deviceIdentifier;
}
await VerifyNotificationAsync(
async sut => await sut.PushLogOutAsync(userId, excludeCurrentContext),
expectedPayload
);
}
[Fact]
public async Task PushSyncSendCreateAsync_SendsExpectedResponse()
{
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 12,
["Payload"] = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSendCreateAsync(send),
expectedPayload
);
}
[Fact]
public async Task PushSyncSendUpdateAsync_SendsExpectedResponse()
{
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 13,
["Payload"] = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSendUpdateAsync(send),
expectedPayload
);
}
[Fact]
public async Task PushSyncSendDeleteAsync_SendsExpectedResponse()
{
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 14,
["Payload"] = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSendDeleteAsync(send),
expectedPayload
);
}
[Fact]
public async Task PushAuthRequestAsync_SendsExpectedResponse()
{
var authRequest = new AuthRequest
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
};
var expectedPayload = new JsonObject
{
["Type"] = 15,
["Payload"] = new JsonObject
{
["Id"] = authRequest.Id,
["UserId"] = authRequest.UserId,
},
["ContextId"] = _deviceIdentifier,
};
await VerifyNotificationAsync(
async sut => await sut.PushAuthRequestAsync(authRequest),
expectedPayload
);
}
[Fact]
public async Task PushAuthRequestResponseAsync_SendsExpectedResponse()
{
var authRequest = new AuthRequest
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
};
var expectedPayload = new JsonObject
{
["Type"] = 16,
["Payload"] = new JsonObject
{
["Id"] = authRequest.Id,
["UserId"] = authRequest.UserId,
},
["ContextId"] = _deviceIdentifier,
};
await VerifyNotificationAsync(
async sut => await sut.PushAuthRequestResponseAsync(authRequest),
expectedPayload
);
}
[Theory]
[InlineData(true, null, null)]
[InlineData(false, "e8e08ce8-8a26-4a65-913a-ba1d8c478b2f", null)]
[InlineData(false, null, "2f53ee32-edf9-4169-b276-760fe92e03bf")]
public async Task PushNotificationAsync_SendsExpectedResponse(bool global, string? userId, string? organizationId)
{
var notification = new Notification
{
Id = Guid.NewGuid(),
Priority = Priority.High,
Global = global,
ClientType = ClientType.All,
UserId = userId != null ? Guid.Parse(userId) : null,
OrganizationId = organizationId != null ? Guid.Parse(organizationId) : null,
Title = "My Title",
Body = "My Body",
CreationDate = DateTime.UtcNow.AddDays(-1),
RevisionDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 20,
["Payload"] = new JsonObject
{
["Id"] = notification.Id,
["Priority"] = 3,
["Global"] = global,
["ClientType"] = 0,
["UserId"] = notification.UserId,
["OrganizationId"] = notification.OrganizationId,
["InstallationId"] = _globalSettings.Installation.Id,
["Title"] = notification.Title,
["Body"] = notification.Body,
["CreationDate"] = notification.CreationDate,
["RevisionDate"] = notification.RevisionDate,
},
["ContextId"] = _deviceIdentifier,
};
if (!global)
{
expectedPayload["Payload"]!.AsObject().Remove("InstallationId");
}
if (!notification.UserId.HasValue)
{
expectedPayload["Payload"]!.AsObject().Remove("UserId");
}
if (!notification.OrganizationId.HasValue)
{
expectedPayload["Payload"]!.AsObject().Remove("OrganizationId");
}
await VerifyNotificationAsync(
async sut => await sut.PushNotificationAsync(notification),
expectedPayload
);
}
[Theory]
[InlineData(true, null, null)]
[InlineData(false, "e8e08ce8-8a26-4a65-913a-ba1d8c478b2f", null)]
[InlineData(false, null, "2f53ee32-edf9-4169-b276-760fe92e03bf")]
public async Task PushNotificationStatusAsync_SendsExpectedResponse(bool global, string? userId, string? organizationId)
{
var notification = new Notification
{
Id = Guid.NewGuid(),
Priority = Priority.High,
Global = global,
ClientType = ClientType.All,
UserId = userId != null ? Guid.Parse(userId) : null,
OrganizationId = organizationId != null ? Guid.Parse(organizationId) : null,
Title = "My Title",
Body = "My Body",
CreationDate = DateTime.UtcNow.AddDays(-1),
RevisionDate = DateTime.UtcNow,
};
var notificationStatus = new NotificationStatus
{
ReadDate = DateTime.UtcNow,
DeletedDate = DateTime.UtcNow,
};
var expectedPayload = new JsonObject
{
["Type"] = 21,
["Payload"] = new JsonObject
{
["Id"] = notification.Id,
["Priority"] = 3,
["Global"] = global,
["ClientType"] = 0,
["UserId"] = notification.UserId,
["OrganizationId"] = notification.OrganizationId,
["InstallationId"] = _globalSettings.Installation.Id,
["Title"] = notification.Title,
["Body"] = notification.Body,
["CreationDate"] = notification.CreationDate,
["RevisionDate"] = notification.RevisionDate,
["ReadDate"] = notificationStatus.ReadDate,
["DeletedDate"] = notificationStatus.DeletedDate,
},
["ContextId"] = _deviceIdentifier,
};
if (!global)
{
expectedPayload["Payload"]!.AsObject().Remove("InstallationId");
}
if (!notification.UserId.HasValue)
{
expectedPayload["Payload"]!.AsObject().Remove("UserId");
}
if (!notification.OrganizationId.HasValue)
{
expectedPayload["Payload"]!.AsObject().Remove("OrganizationId");
}
await VerifyNotificationAsync(
async sut => await sut.PushNotificationStatusAsync(notification, notificationStatus),
expectedPayload
);
}
[Fact]
public async Task PushSyncOrganizationStatusAsync_SendsExpectedResponse()
{
var organization = new Organization
{
Id = Guid.NewGuid(),
Enabled = true,
};
var expectedPayload = new JsonObject
{
["Type"] = 18,
["Payload"] = new JsonObject
{
["OrganizationId"] = organization.Id,
["Enabled"] = organization.Enabled,
},
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrganizationStatusAsync(organization),
expectedPayload
);
}
[Fact]
public async Task PushSyncOrganizationCollectionManagementSettingsAsync_SendsExpectedResponse()
{
var organization = new Organization
{
Id = Guid.NewGuid(),
Enabled = true,
LimitCollectionCreation = true,
LimitCollectionDeletion = true,
LimitItemDeletion = true,
};
var expectedPayload = new JsonObject
{
["Type"] = 19,
["Payload"] = new JsonObject
{
["OrganizationId"] = organization.Id,
["LimitCollectionCreation"] = organization.LimitCollectionCreation,
["LimitCollectionDeletion"] = organization.LimitCollectionDeletion,
["LimitItemDeletion"] = organization.LimitItemDeletion,
},
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrganizationCollectionManagementSettingsAsync(organization),
expectedPayload
);
}
[Fact]
public async Task PushPendingSecurityTasksAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
var expectedPayload = new JsonObject
{
["Type"] = 22,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = _fakeTimeProvider.GetUtcNow().UtcDateTime,
},
};
await VerifyNotificationAsync(
async sut => await sut.PushPendingSecurityTasksAsync(userId),
expectedPayload
);
}
// [Fact]
// public async Task SendPayloadToInstallationAsync_ThrowsNotImplementedException()
// {
// await Assert.ThrowsAsync<NotImplementedException>(
// async () => await sut.SendPayloadToInstallationAsync("installation_id", PushType.AuthRequest, new {}, null)
// );
// }
// [Fact]
// public async Task SendPayloadToUserAsync_ThrowsNotImplementedException()
// {
// await Assert.ThrowsAsync<NotImplementedException>(
// async () => await _sut.SendPayloadToUserAsync("user_id", PushType.AuthRequest, new {}, null)
// );
// }
// [Fact]
// public async Task SendPayloadToOrganizationAsync_ThrowsNotImplementedException()
// {
// await Assert.ThrowsAsync<NotImplementedException>(
// async () => await _sut.SendPayloadToOrganizationAsync("organization_id", PushType.AuthRequest, new {}, null)
// );
// }
private async Task VerifyNotificationAsync(Func<AzureQueuePushNotificationService, Task> test, JsonNode expectedMessage)
{
var queueClient = Substitute.For<QueueClient>();
var httpContextAccessor = Substitute.For<IHttpContextAccessor>();
var httpContext = new DefaultHttpContext();
var serviceCollection = new ServiceCollection();
var currentContext = Substitute.For<ICurrentContext>();
currentContext.DeviceIdentifier = _deviceIdentifier;
serviceCollection.AddSingleton(currentContext);
httpContext.RequestServices = serviceCollection.BuildServiceProvider();
httpContextAccessor.HttpContext
.Returns(httpContext);
var globalSettings = new Core.Settings.GlobalSettings();
var sut = new AzureQueuePushNotificationService(
queueClient,
httpContextAccessor,
globalSettings,
NullLogger<AzureQueuePushNotificationService>.Instance,
_fakeTimeProvider
);
await test(sut);
// Hoist equality checker outside the expression so that we
// can more easily place a breakpoint
var checkEquality = (string actual) =>
{
var actualNode = JsonNode.Parse(actual);
return JsonNode.DeepEquals(actualNode, expectedMessage);
};
await queueClient
.Received(1)
.SendMessageAsync(Arg.Is<string>((actual) => checkEquality(actual)));
}
private static bool MatchMessage<T>(PushType pushType, string message, IEquatable<T> expectedPayloadEquatable, private static bool MatchMessage<T>(PushType pushType, string message, IEquatable<T> expectedPayloadEquatable,
string contextId) string contextId)
{ {

View File

@ -1,41 +1,385 @@
using Bit.Core.Platform.Push; using System.Text.Json.Nodes;
using Bit.Core.Settings; using Bit.Core.AdminConsole.Entities;
using Microsoft.AspNetCore.Http; using Bit.Core.Auth.Entities;
using Microsoft.Extensions.Logging; using Bit.Core.NotificationCenter.Entities;
using NSubstitute; using Bit.Core.Platform.Push;
using Xunit; using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities;
using Microsoft.Extensions.Logging.Abstractions;
namespace Bit.Core.Test.Platform.Push.Services; namespace Bit.Core.Test.Platform.Push.Services;
public class NotificationsApiPushNotificationServiceTests public class NotificationsApiPushNotificationServiceTests : PushTestBase
{ {
private readonly NotificationsApiPushNotificationService _sut;
private readonly IHttpClientFactory _httpFactory;
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
private readonly ILogger<NotificationsApiPushNotificationService> _logger;
public NotificationsApiPushNotificationServiceTests() public NotificationsApiPushNotificationServiceTests()
{ {
_httpFactory = Substitute.For<IHttpClientFactory>(); GlobalSettings.BaseServiceUri.InternalNotifications = "https://localhost:7777";
_globalSettings = new GlobalSettings(); GlobalSettings.BaseServiceUri.InternalIdentity = "https://localhost:8888";
_httpContextAccessor = Substitute.For<IHttpContextAccessor>(); }
_logger = Substitute.For<ILogger<NotificationsApiPushNotificationService>>();
_sut = new NotificationsApiPushNotificationService( protected override string ExpectedClientUrl() => "https://localhost:7777/send";
_httpFactory,
_globalSettings, protected override IPushNotificationService CreateService()
_httpContextAccessor, {
_logger return new NotificationsApiPushNotificationService(
HttpClientFactory,
GlobalSettings,
HttpContextAccessor,
NullLogger<NotificationsApiPushNotificationService>.Instance,
FakeTimeProvider
); );
} }
// Remove this test when we add actual tests. It only proves that protected override JsonNode GetPushSyncCipherCreatePayload(Cipher cipher, Guid collectionId)
// we've properly constructed the system under test.
[Fact(Skip = "Needs additional work")]
public void ServiceExists()
{ {
Assert.NotNull(_sut); return new JsonObject
{
["Type"] = 1,
["Payload"] = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = null,
["CollectionIds"] = new JsonArray(collectionId),
["RevisionDate"] = cipher.RevisionDate,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushSyncCipherUpdatePayload(Cipher cipher, Guid collectionId)
{
return new JsonObject
{
["Type"] = 0,
["Payload"] = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = null,
["CollectionIds"] = new JsonArray(collectionId),
["RevisionDate"] = cipher.RevisionDate,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushSyncCipherDeletePayload(Cipher cipher)
{
return new JsonObject
{
["Type"] = 2,
["Payload"] = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = null,
["CollectionIds"] = null,
["RevisionDate"] = cipher.RevisionDate,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushSyncFolderCreatePayload(Folder folder)
{
return new JsonObject
{
["Type"] = 7,
["Payload"] = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushSyncFolderUpdatePayload(Folder folder)
{
return new JsonObject
{
["Type"] = 8,
["Payload"] = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushSyncFolderDeletePayload(Folder folder)
{
return new JsonObject
{
["Type"] = 3,
["Payload"] = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushSyncCiphersPayload(Guid userId)
{
return new JsonObject
{
["Type"] = 4,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushSyncVaultPayload(Guid userId)
{
return new JsonObject
{
["Type"] = 5,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushSyncOrganizationsPayload(Guid userId)
{
return new JsonObject
{
["Type"] = 17,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushSyncOrgKeysPayload(Guid userId)
{
return new JsonObject
{
["Type"] = 6,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushSyncSettingsPayload(Guid userId)
{
return new JsonObject
{
["Type"] = 10,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushLogOutPayload(Guid userId, bool excludeCurrentContext)
{
JsonNode? contextId = excludeCurrentContext ? DeviceIdentifier : null;
return new JsonObject
{
["Type"] = 11,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ContextId"] = contextId,
};
}
protected override JsonNode GetPushSendCreatePayload(Send send)
{
return new JsonObject
{
["Type"] = 12,
["Payload"] = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushSendUpdatePayload(Send send)
{
return new JsonObject
{
["Type"] = 13,
["Payload"] = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushSendDeletePayload(Send send)
{
return new JsonObject
{
["Type"] = 14,
["Payload"] = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushAuthRequestPayload(AuthRequest authRequest)
{
return new JsonObject
{
["Type"] = 15,
["Payload"] = new JsonObject
{
["Id"] = authRequest.Id,
["UserId"] = authRequest.UserId,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushAuthRequestResponsePayload(AuthRequest authRequest)
{
return new JsonObject
{
["Type"] = 16,
["Payload"] = new JsonObject
{
["Id"] = authRequest.Id,
["UserId"] = authRequest.UserId,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushNotificationResponsePayload(Notification notification, Guid? userId, Guid? organizationId)
{
JsonNode? installationId = notification.Global ? GlobalSettings.Installation.Id : null;
return new JsonObject
{
["Type"] = 20,
["Payload"] = new JsonObject
{
["Id"] = notification.Id,
["Priority"] = 3,
["Global"] = notification.Global,
["ClientType"] = 0,
["UserId"] = notification.UserId,
["OrganizationId"] = notification.OrganizationId,
["TaskId"] = notification.TaskId,
["InstallationId"] = installationId,
["Title"] = notification.Title,
["Body"] = notification.Body,
["CreationDate"] = notification.CreationDate,
["RevisionDate"] = notification.RevisionDate,
["ReadDate"] = null,
["DeletedDate"] = null,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushNotificationStatusResponsePayload(Notification notification, NotificationStatus notificationStatus, Guid? userId, Guid? organizationId)
{
JsonNode? installationId = notification.Global ? GlobalSettings.Installation.Id : null;
return new JsonObject
{
["Type"] = 21,
["Payload"] = new JsonObject
{
["Id"] = notification.Id,
["Priority"] = 3,
["Global"] = notification.Global,
["ClientType"] = 0,
["UserId"] = notification.UserId,
["OrganizationId"] = notification.OrganizationId,
["InstallationId"] = installationId,
["TaskId"] = notification.TaskId,
["Title"] = notification.Title,
["Body"] = notification.Body,
["CreationDate"] = notification.CreationDate,
["RevisionDate"] = notification.RevisionDate,
["ReadDate"] = notificationStatus.ReadDate,
["DeletedDate"] = notificationStatus.DeletedDate,
},
["ContextId"] = DeviceIdentifier,
};
}
protected override JsonNode GetPushSyncOrganizationStatusResponsePayload(Organization organization)
{
return new JsonObject
{
["Type"] = 18,
["Payload"] = new JsonObject
{
["OrganizationId"] = organization.Id,
["Enabled"] = organization.Enabled,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushSyncOrganizationCollectionManagementSettingsResponsePayload(Organization organization)
{
return new JsonObject
{
["Type"] = 19,
["Payload"] = new JsonObject
{
["OrganizationId"] = organization.Id,
["LimitCollectionCreation"] = organization.LimitCollectionCreation,
["LimitCollectionDeletion"] = organization.LimitCollectionDeletion,
["LimitItemDeletion"] = organization.LimitItemDeletion,
},
["ContextId"] = null,
};
}
protected override JsonNode GetPushPendingSecurityTasksResponsePayload(Guid userId)
{
return new JsonObject
{
["Type"] = 22,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ContextId"] = null,
};
} }
} }

View File

@ -0,0 +1,498 @@
using System.IdentityModel.Tokens.Jwt;
using System.Net;
using System.Net.Http.Json;
using System.Text.Json;
using System.Text.Json.Nodes;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Enums;
using Bit.Core.Platform.Push;
using Bit.Core.Settings;
using Bit.Core.Tools.Entities;
using Bit.Core.Vault.Entities;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Time.Testing;
using NSubstitute;
using RichardSzalay.MockHttp;
using Xunit;
public abstract class PushTestBase
{
protected static readonly string DeviceIdentifier = "test_device_identifier";
protected readonly MockHttpMessageHandler MockClient = new();
protected readonly MockHttpMessageHandler MockIdentityClient = new();
protected readonly IHttpClientFactory HttpClientFactory;
protected readonly GlobalSettings GlobalSettings;
protected readonly IHttpContextAccessor HttpContextAccessor;
protected readonly FakeTimeProvider FakeTimeProvider;
public PushTestBase()
{
HttpClientFactory = Substitute.For<IHttpClientFactory>();
// Mock HttpClient
HttpClientFactory.CreateClient("client")
.Returns(new HttpClient(MockClient));
HttpClientFactory.CreateClient("identity")
.Returns(new HttpClient(MockIdentityClient));
GlobalSettings = new GlobalSettings();
HttpContextAccessor = Substitute.For<IHttpContextAccessor>();
FakeTimeProvider = new FakeTimeProvider();
FakeTimeProvider.SetUtcNow(DateTimeOffset.UtcNow);
}
protected abstract IPushNotificationService CreateService();
protected abstract string ExpectedClientUrl();
protected abstract JsonNode GetPushSyncCipherCreatePayload(Cipher cipher, Guid collectionId);
protected abstract JsonNode GetPushSyncCipherUpdatePayload(Cipher cipher, Guid collectionId);
protected abstract JsonNode GetPushSyncCipherDeletePayload(Cipher cipher);
protected abstract JsonNode GetPushSyncFolderCreatePayload(Folder folder);
protected abstract JsonNode GetPushSyncFolderUpdatePayload(Folder folder);
protected abstract JsonNode GetPushSyncFolderDeletePayload(Folder folder);
protected abstract JsonNode GetPushSyncCiphersPayload(Guid userId);
protected abstract JsonNode GetPushSyncVaultPayload(Guid userId);
protected abstract JsonNode GetPushSyncOrganizationsPayload(Guid userId);
protected abstract JsonNode GetPushSyncOrgKeysPayload(Guid userId);
protected abstract JsonNode GetPushSyncSettingsPayload(Guid userId);
protected abstract JsonNode GetPushLogOutPayload(Guid userId, bool excludeCurrentContext);
protected abstract JsonNode GetPushSendCreatePayload(Send send);
protected abstract JsonNode GetPushSendUpdatePayload(Send send);
protected abstract JsonNode GetPushSendDeletePayload(Send send);
protected abstract JsonNode GetPushAuthRequestPayload(AuthRequest authRequest);
protected abstract JsonNode GetPushAuthRequestResponsePayload(AuthRequest authRequest);
protected abstract JsonNode GetPushNotificationResponsePayload(Notification notification, Guid? userId, Guid? organizationId);
protected abstract JsonNode GetPushNotificationStatusResponsePayload(Notification notification, NotificationStatus notificationStatus, Guid? userId, Guid? organizationId);
protected abstract JsonNode GetPushSyncOrganizationStatusResponsePayload(Organization organization);
protected abstract JsonNode GetPushSyncOrganizationCollectionManagementSettingsResponsePayload(Organization organization);
protected abstract JsonNode GetPushPendingSecurityTasksResponsePayload(Guid userId);
[Fact]
public async Task PushSyncCipherCreateAsync_SendsExpectedResponse()
{
var collectionId = Guid.NewGuid();
var cipher = new Cipher
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
OrganizationId = null,
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncCipherCreateAsync(cipher, [collectionId]),
GetPushSyncCipherCreatePayload(cipher, collectionId)
);
}
[Fact]
public async Task PushSyncCipherUpdateAsync_SendsExpectedResponse()
{
var collectionId = Guid.NewGuid();
var cipher = new Cipher
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
OrganizationId = null,
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncCipherUpdateAsync(cipher, [collectionId]),
GetPushSyncCipherUpdatePayload(cipher, collectionId)
);
}
[Fact]
public async Task PushSyncCipherDeleteAsync_SendsExpectedResponse()
{
var collectionId = Guid.NewGuid();
var cipher = new Cipher
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
OrganizationId = null,
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncCipherDeleteAsync(cipher),
GetPushSyncCipherDeletePayload(cipher)
);
}
[Fact]
public async Task PushSyncFolderCreateAsync_SendsExpectedResponse()
{
var folder = new Folder
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncFolderCreateAsync(folder),
GetPushSyncFolderCreatePayload(folder)
);
}
[Fact]
public async Task PushSyncFolderUpdateAsync_SendsExpectedResponse()
{
var collectionId = Guid.NewGuid();
var folder = new Folder
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncFolderUpdateAsync(folder),
GetPushSyncFolderUpdatePayload(folder)
);
}
[Fact]
public async Task PushSyncFolderDeleteAsync_SendsExpectedResponse()
{
var collectionId = Guid.NewGuid();
var folder = new Folder
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncFolderDeleteAsync(folder),
GetPushSyncFolderDeletePayload(folder)
);
}
[Fact]
public async Task PushSyncCiphersAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
await VerifyNotificationAsync(
async sut => await sut.PushSyncCiphersAsync(userId),
GetPushSyncCiphersPayload(userId)
);
}
[Fact]
public async Task PushSyncVaultAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
await VerifyNotificationAsync(
async sut => await sut.PushSyncVaultAsync(userId),
GetPushSyncVaultPayload(userId)
);
}
[Fact]
public async Task PushSyncOrganizationsAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrganizationsAsync(userId),
GetPushSyncOrganizationsPayload(userId)
);
}
[Fact]
public async Task PushSyncOrgKeysAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrgKeysAsync(userId),
GetPushSyncOrgKeysPayload(userId)
);
}
[Fact]
public async Task PushSyncSettingsAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
await VerifyNotificationAsync(
async sut => await sut.PushSyncSettingsAsync(userId),
GetPushSyncSettingsPayload(userId)
);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task PushLogOutAsync_SendsExpectedResponse(bool excludeCurrentContext)
{
var userId = Guid.NewGuid();
await VerifyNotificationAsync(
async sut => await sut.PushLogOutAsync(userId, excludeCurrentContext),
GetPushLogOutPayload(userId, excludeCurrentContext)
);
}
[Fact]
public async Task PushSyncSendCreateAsync_SendsExpectedResponse()
{
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSendCreateAsync(send),
GetPushSendCreatePayload(send)
);
}
[Fact]
public async Task PushSyncSendUpdateAsync_SendsExpectedResponse()
{
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSendUpdateAsync(send),
GetPushSendUpdatePayload(send)
);
}
[Fact]
public async Task PushSyncSendDeleteAsync_SendsExpectedResponse()
{
var send = new Send
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncSendDeleteAsync(send),
GetPushSendDeletePayload(send)
);
}
[Fact]
public async Task PushAuthRequestAsync_SendsExpectedResponse()
{
var authRequest = new AuthRequest
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
};
await VerifyNotificationAsync(
async sut => await sut.PushAuthRequestAsync(authRequest),
GetPushAuthRequestPayload(authRequest)
);
}
[Fact]
public async Task PushAuthRequestResponseAsync_SendsExpectedResponse()
{
var authRequest = new AuthRequest
{
Id = Guid.NewGuid(),
UserId = Guid.NewGuid(),
};
await VerifyNotificationAsync(
async sut => await sut.PushAuthRequestResponseAsync(authRequest),
GetPushAuthRequestResponsePayload(authRequest)
);
}
[Theory]
[InlineData(true, null, null)]
[InlineData(false, "e8e08ce8-8a26-4a65-913a-ba1d8c478b2f", null)]
[InlineData(false, null, "2f53ee32-edf9-4169-b276-760fe92e03bf")]
public async Task PushNotificationAsync_SendsExpectedResponse(bool global, string? userId, string? organizationId)
{
var notification = new Notification
{
Id = Guid.NewGuid(),
Priority = Priority.High,
Global = global,
ClientType = ClientType.All,
UserId = userId != null ? Guid.Parse(userId) : null,
OrganizationId = organizationId != null ? Guid.Parse(organizationId) : null,
TaskId = Guid.NewGuid(),
Title = "My Title",
Body = "My Body",
CreationDate = DateTime.UtcNow.AddDays(-1),
RevisionDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushNotificationAsync(notification),
GetPushNotificationResponsePayload(notification, notification.UserId, notification.OrganizationId)
);
}
[Theory]
[InlineData(true, null, null)]
[InlineData(false, "e8e08ce8-8a26-4a65-913a-ba1d8c478b2f", null)]
[InlineData(false, null, "2f53ee32-edf9-4169-b276-760fe92e03bf")]
public async Task PushNotificationStatusAsync_SendsExpectedResponse(bool global, string? userId, string? organizationId)
{
var notification = new Notification
{
Id = Guid.NewGuid(),
Priority = Priority.High,
Global = global,
ClientType = ClientType.All,
UserId = userId != null ? Guid.Parse(userId) : null,
OrganizationId = organizationId != null ? Guid.Parse(organizationId) : null,
TaskId = Guid.NewGuid(),
Title = "My Title",
Body = "My Body",
CreationDate = DateTime.UtcNow.AddDays(-1),
RevisionDate = DateTime.UtcNow,
};
var notificationStatus = new NotificationStatus
{
ReadDate = DateTime.UtcNow,
DeletedDate = DateTime.UtcNow,
};
await VerifyNotificationAsync(
async sut => await sut.PushNotificationStatusAsync(notification, notificationStatus),
GetPushNotificationStatusResponsePayload(notification, notificationStatus, notification.UserId, notification.OrganizationId)
);
}
[Fact]
public async Task PushSyncOrganizationStatusAsync_SendsExpectedResponse()
{
var organization = new Organization
{
Id = Guid.NewGuid(),
Enabled = true,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrganizationStatusAsync(organization),
GetPushSyncOrganizationStatusResponsePayload(organization)
);
}
[Fact]
public async Task PushSyncOrganizationCollectionManagementSettingsAsync_SendsExpectedResponse()
{
var organization = new Organization
{
Id = Guid.NewGuid(),
Enabled = true,
LimitCollectionCreation = true,
LimitCollectionDeletion = true,
LimitItemDeletion = true,
};
await VerifyNotificationAsync(
async sut => await sut.PushSyncOrganizationCollectionManagementSettingsAsync(organization),
GetPushSyncOrganizationCollectionManagementSettingsResponsePayload(organization)
);
}
[Fact]
public async Task PushPendingSecurityTasksAsync_SendsExpectedResponse()
{
var userId = Guid.NewGuid();
await VerifyNotificationAsync(
async sut => await sut.PushPendingSecurityTasksAsync(userId),
GetPushPendingSecurityTasksResponsePayload(userId)
);
}
private async Task VerifyNotificationAsync(
Func<IPushNotificationService, Task> test,
JsonNode expectedRequestBody
)
{
var httpContext = new DefaultHttpContext();
var serviceCollection = new ServiceCollection();
var currentContext = Substitute.For<ICurrentContext>();
currentContext.DeviceIdentifier = DeviceIdentifier;
serviceCollection.AddSingleton(currentContext);
httpContext.RequestServices = serviceCollection.BuildServiceProvider();
HttpContextAccessor.HttpContext
.Returns(httpContext);
var connectTokenRequest = MockIdentityClient
.Expect(HttpMethod.Post, "https://localhost:8888/connect/token")
.Respond(HttpStatusCode.OK, JsonContent.Create(new
{
access_token = CreateAccessToken(DateTime.UtcNow.AddDays(1)),
}));
JsonNode actualNode = null;
var clientRequest = MockClient
.Expect(HttpMethod.Post, ExpectedClientUrl())
.With(request =>
{
if (request.Content is not JsonContent jsonContent)
{
return false;
}
// TODO: What options?
var actualString = JsonSerializer.Serialize(jsonContent.Value);
actualNode = JsonNode.Parse(actualString);
return JsonNode.DeepEquals(actualNode, expectedRequestBody);
})
.Respond(HttpStatusCode.OK);
await test(CreateService());
Assert.NotNull(actualNode);
Assert.Equal(expectedRequestBody, actualNode, EqualityComparer<JsonNode>.Create(JsonNode.DeepEquals));
Assert.Equal(1, MockClient.GetMatchCount(clientRequest));
}
protected static string CreateAccessToken(DateTime expirationTime)
{
var tokenHandler = new JwtSecurityTokenHandler();
var token = new JwtSecurityToken(expires: expirationTime);
return tokenHandler.WriteToken(token);
}
}

View File

@ -1,45 +1,541 @@
using Bit.Core.Platform.Push.Internal; #nullable enable
using System.Text.Json.Nodes;
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Auth.Entities;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities;
using Bit.Core.Platform.Push.Internal;
using Bit.Core.Repositories; using Bit.Core.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;
using Microsoft.AspNetCore.Http; using Bit.Core.Tools.Entities;
using Microsoft.Extensions.Logging; using Bit.Core.Vault.Entities;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Time.Testing;
using NSubstitute; using NSubstitute;
using Xunit; using Xunit;
namespace Bit.Core.Test.Platform.Push.Services; namespace Bit.Core.Test.Platform.Push.Services;
public class RelayPushNotificationServiceTests public class RelayPushNotificationServiceTests : PushTestBase
{ {
private readonly RelayPushNotificationService _sut; private static readonly Guid _deviceId = Guid.Parse("c4730f80-caaa-4772-97bd-5c0d23a2baa3");
private readonly IHttpClientFactory _httpFactory;
private readonly IDeviceRepository _deviceRepository; private readonly IDeviceRepository _deviceRepository;
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
private readonly ILogger<RelayPushNotificationService> _logger;
public RelayPushNotificationServiceTests() public RelayPushNotificationServiceTests()
{ {
_httpFactory = Substitute.For<IHttpClientFactory>();
_deviceRepository = Substitute.For<IDeviceRepository>(); _deviceRepository = Substitute.For<IDeviceRepository>();
_globalSettings = new GlobalSettings();
_httpContextAccessor = Substitute.For<IHttpContextAccessor>();
_logger = Substitute.For<ILogger<RelayPushNotificationService>>();
_sut = new RelayPushNotificationService( _deviceRepository.GetByIdentifierAsync(DeviceIdentifier)
_httpFactory, .Returns(new Device
{
Id = _deviceId,
});
GlobalSettings.PushRelayBaseUri = "https://localhost:7777";
GlobalSettings.Installation.Id = Guid.Parse("478c608a-99fd-452a-94f0-af271654e6ee");
GlobalSettings.Installation.IdentityUri = "https://localhost:8888";
}
protected override RelayPushNotificationService CreateService()
{
return new RelayPushNotificationService(
HttpClientFactory,
_deviceRepository, _deviceRepository,
_globalSettings, GlobalSettings,
_httpContextAccessor, HttpContextAccessor,
_logger NullLogger<RelayPushNotificationService>.Instance,
FakeTimeProvider
); );
} }
// Remove this test when we add actual tests. It only proves that protected override string ExpectedClientUrl() => "https://localhost:7777/push/send";
// we've properly constructed the system under test.
[Fact(Skip = "Needs additional work")] [Fact]
public void ServiceExists() public async Task SendPayloadToInstallationAsync_ThrowsNotImplementedException()
{ {
Assert.NotNull(_sut); var sut = CreateService();
await Assert.ThrowsAsync<NotImplementedException>(
async () => await sut.SendPayloadToInstallationAsync("installation_id", PushType.AuthRequest, new { }, null)
);
}
[Fact]
public async Task SendPayloadToUserAsync_ThrowsNotImplementedException()
{
var sut = CreateService();
await Assert.ThrowsAsync<NotImplementedException>(
async () => await sut.SendPayloadToUserAsync("user_id", PushType.AuthRequest, new { }, null)
);
}
[Fact]
public async Task SendPayloadToOrganizationAsync_ThrowsNotImplementedException()
{
var sut = CreateService();
await Assert.ThrowsAsync<NotImplementedException>(
async () => await sut.SendPayloadToOrganizationAsync("organization_id", PushType.AuthRequest, new { }, null)
);
}
protected override JsonNode GetPushSyncCipherCreatePayload(Cipher cipher, Guid collectionIds)
{
return new JsonObject
{
["UserId"] = cipher.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 1,
["Payload"] = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = null,
// Currently CollectionIds are not passed along from the method signature
// to the request body.
["CollectionIds"] = null,
["RevisionDate"] = cipher.RevisionDate,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncCipherUpdatePayload(Cipher cipher, Guid collectionIds)
{
return new JsonObject
{
["UserId"] = cipher.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 0,
["Payload"] = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = null,
// Currently CollectionIds are not passed along from the method signature
// to the request body.
["CollectionIds"] = null,
["RevisionDate"] = cipher.RevisionDate,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncCipherDeletePayload(Cipher cipher)
{
return new JsonObject
{
["UserId"] = cipher.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 2,
["Payload"] = new JsonObject
{
["Id"] = cipher.Id,
["UserId"] = cipher.UserId,
["OrganizationId"] = null,
["CollectionIds"] = null,
["RevisionDate"] = cipher.RevisionDate,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncFolderCreatePayload(Folder folder)
{
return new JsonObject
{
["UserId"] = folder.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 7,
["Payload"] = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncFolderUpdatePayload(Folder folder)
{
return new JsonObject
{
["UserId"] = folder.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 8,
["Payload"] = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncFolderDeletePayload(Folder folder)
{
return new JsonObject
{
["UserId"] = folder.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 3,
["Payload"] = new JsonObject
{
["Id"] = folder.Id,
["UserId"] = folder.UserId,
["RevisionDate"] = folder.RevisionDate,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncCiphersPayload(Guid userId)
{
return new JsonObject
{
["UserId"] = userId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = null,
["Type"] = 4,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncVaultPayload(Guid userId)
{
return new JsonObject
{
["UserId"] = userId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = null,
["Type"] = 5,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncOrganizationsPayload(Guid userId)
{
return new JsonObject
{
["UserId"] = userId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = null,
["Type"] = 17,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncOrgKeysPayload(Guid userId)
{
return new JsonObject
{
["UserId"] = userId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = null,
["Type"] = 6,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncSettingsPayload(Guid userId)
{
return new JsonObject
{
["UserId"] = userId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = null,
["Type"] = 10,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushLogOutPayload(Guid userId, bool excludeCurrentContext)
{
JsonNode? identifier = excludeCurrentContext ? DeviceIdentifier : null;
return new JsonObject
{
["UserId"] = userId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = identifier,
["Type"] = 11,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSendCreatePayload(Send send)
{
return new JsonObject
{
["UserId"] = send.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 12,
["Payload"] = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSendUpdatePayload(Send send)
{
return new JsonObject
{
["UserId"] = send.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 13,
["Payload"] = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSendDeletePayload(Send send)
{
return new JsonObject
{
["UserId"] = send.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 14,
["Payload"] = new JsonObject
{
["Id"] = send.Id,
["UserId"] = send.UserId,
["RevisionDate"] = send.RevisionDate,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushAuthRequestPayload(AuthRequest authRequest)
{
return new JsonObject
{
["UserId"] = authRequest.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 15,
["Payload"] = new JsonObject
{
["Id"] = authRequest.Id,
["UserId"] = authRequest.UserId,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushAuthRequestResponsePayload(AuthRequest authRequest)
{
return new JsonObject
{
["UserId"] = authRequest.UserId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 16,
["Payload"] = new JsonObject
{
["Id"] = authRequest.Id,
["UserId"] = authRequest.UserId,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushNotificationResponsePayload(Notification notification, Guid? userId, Guid? organizationId)
{
JsonNode? installationId = notification.Global ? GlobalSettings.Installation.Id : null;
return new JsonObject
{
["UserId"] = notification.UserId,
["OrganizationId"] = notification.OrganizationId,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 20,
["Payload"] = new JsonObject
{
["Id"] = notification.Id,
["Priority"] = 3,
["Global"] = notification.Global,
["ClientType"] = 0,
["UserId"] = userId,
["OrganizationId"] = organizationId,
["TaskId"] = notification.TaskId,
["InstallationId"] = installationId,
["Title"] = notification.Title,
["Body"] = notification.Body,
["CreationDate"] = notification.CreationDate,
["RevisionDate"] = notification.RevisionDate,
["ReadDate"] = null,
["DeletedDate"] = null,
},
["ClientType"] = 0,
["InstallationId"] = installationId?.DeepClone(),
};
}
protected override JsonNode GetPushNotificationStatusResponsePayload(Notification notification, NotificationStatus notificationStatus, Guid? userId, Guid? organizationId)
{
JsonNode? installationId = notification.Global ? GlobalSettings.Installation.Id : null;
return new JsonObject
{
["UserId"] = notification.UserId,
["OrganizationId"] = notification.OrganizationId,
["DeviceId"] = _deviceId,
["Identifier"] = DeviceIdentifier,
["Type"] = 21,
["Payload"] = new JsonObject
{
["Id"] = notification.Id,
["Priority"] = 3,
["Global"] = notification.Global,
["ClientType"] = 0,
["UserId"] = notification.UserId,
["OrganizationId"] = notification.OrganizationId,
["InstallationId"] = installationId,
["TaskId"] = notification.TaskId,
["Title"] = notification.Title,
["Body"] = notification.Body,
["CreationDate"] = notification.CreationDate,
["RevisionDate"] = notification.RevisionDate,
["ReadDate"] = notificationStatus.ReadDate,
["DeletedDate"] = notificationStatus.DeletedDate,
},
["ClientType"] = 0,
["InstallationId"] = installationId?.DeepClone(),
};
}
protected override JsonNode GetPushSyncOrganizationStatusResponsePayload(Organization organization)
{
return new JsonObject
{
["UserId"] = null,
["OrganizationId"] = organization.Id,
["DeviceId"] = _deviceId,
["Identifier"] = null,
["Type"] = 18,
["Payload"] = new JsonObject
{
["OrganizationId"] = organization.Id,
["Enabled"] = organization.Enabled,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushSyncOrganizationCollectionManagementSettingsResponsePayload(Organization organization)
{
return new JsonObject
{
["UserId"] = null,
["OrganizationId"] = organization.Id,
["DeviceId"] = _deviceId,
["Identifier"] = null,
["Type"] = 19,
["Payload"] = new JsonObject
{
["OrganizationId"] = organization.Id,
["LimitCollectionCreation"] = organization.LimitCollectionCreation,
["LimitCollectionDeletion"] = organization.LimitCollectionDeletion,
["LimitItemDeletion"] = organization.LimitItemDeletion,
},
["ClientType"] = null,
["InstallationId"] = null,
};
}
protected override JsonNode GetPushPendingSecurityTasksResponsePayload(Guid userId)
{
return new JsonObject
{
["UserId"] = userId,
["OrganizationId"] = null,
["DeviceId"] = _deviceId,
["Identifier"] = null,
["Type"] = 22,
["Payload"] = new JsonObject
{
["UserId"] = userId,
["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime,
},
["ClientType"] = null,
["InstallationId"] = null,
};
} }
} }

View File

@ -883,4 +883,33 @@ public class CipherRepositoryTests
Assert.Contains(user2TaskCiphers, t => t.CipherId == manageCipher1.Id && t.TaskId == securityTasks[0].Id); Assert.Contains(user2TaskCiphers, t => t.CipherId == manageCipher1.Id && t.TaskId == securityTasks[0].Id);
Assert.Contains(user2TaskCiphers, t => t.CipherId == manageCipher2.Id && t.TaskId == securityTasks[1].Id); Assert.Contains(user2TaskCiphers, t => t.CipherId == manageCipher2.Id && t.TaskId == securityTasks[1].Id);
} }
[DatabaseTheory, DatabaseData]
public async Task UpdateCiphersAsync_Works(ICipherRepository cipherRepository, IUserRepository userRepository)
{
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{Guid.NewGuid()}@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var cipher1 = await CreatePersonalCipher(user, cipherRepository);
var cipher2 = await CreatePersonalCipher(user, cipherRepository);
cipher1.Type = CipherType.SecureNote;
cipher2.Attachments = "new_attachments";
await cipherRepository.UpdateCiphersAsync(user.Id, [cipher1, cipher2]);
var updatedCipher1 = await cipherRepository.GetByIdAsync(cipher1.Id);
var updatedCipher2 = await cipherRepository.GetByIdAsync(cipher2.Id);
Assert.NotNull(updatedCipher1);
Assert.NotNull(updatedCipher2);
Assert.Equal(CipherType.SecureNote, updatedCipher1.Type);
Assert.Equal("new_attachments", updatedCipher2.Attachments);
}
} }