From 39c560bbdda36ff24455e8e7599e431e8c1b86f0 Mon Sep 17 00:00:00 2001 From: Daniel James Smith <2670567+djsmith85@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:01:23 +0200 Subject: [PATCH 1/5] Add generator-tools-modernization feature flag (#4933) Co-authored-by: Daniel James Smith --- src/Core/Constants.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index ecbe190ccd..f193f7995a 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -148,6 +148,7 @@ public static class FeatureFlagKeys public const string VerifiedSsoDomainEndpoint = "pm-12337-refactor-sso-details-endpoint"; public const string Pm13322AddPolicyDefinitions = "pm-13322-add-policy-definitions"; public const string LimitCollectionCreationDeletionSplit = "pm-10863-limit-collection-creation-deletion-split"; + public const string GeneratorToolsModernization = "generator-tools-modernization"; public static List GetAllKeys() { From a952d106374d838b1b526b3289a8d152240c1904 Mon Sep 17 00:00:00 2001 From: Jonas Hendrickx Date: Wed, 23 Oct 2024 18:10:50 +0200 Subject: [PATCH 2/5] [PM-13447] Add Multi Org Enterprise providers to Admin Console (#4920) --- .../Providers/CreateProviderCommand.cs | 37 ++- .../CreateProviderCommandTests.cs | 49 ++++ .../Controllers/ProvidersController.cs | 106 +++++++- .../Models/CreateMspProviderModel.cs | 45 ++++ ...ultiOrganizationEnterpriseProviderModel.cs | 47 ++++ .../Models/CreateProviderModel.cs | 80 +----- .../Models/CreateResellerProviderModel.cs | 48 ++++ .../Views/Providers/Create.cshtml | 80 ++---- .../Views/Providers/CreateMsp.cshtml | 39 +++ .../CreateMultiOrganizationEnterprise.cshtml | 43 +++ .../Views/Providers/CreateReseller.cshtml | 25 ++ src/Admin/Enums/HtmlHelperExtensions.cs | 19 ++ .../Enums/Provider/ProviderType.cs | 6 +- .../Interfaces/ICreateProviderCommand.cs | 2 + src/Core/Constants.cs | 1 + .../Controllers/ProvidersControllerTests.cs | 251 ++++++++++++++++++ 16 files changed, 717 insertions(+), 161 deletions(-) create mode 100644 src/Admin/AdminConsole/Models/CreateMspProviderModel.cs create mode 100644 src/Admin/AdminConsole/Models/CreateMultiOrganizationEnterpriseProviderModel.cs create mode 100644 src/Admin/AdminConsole/Models/CreateResellerProviderModel.cs create mode 100644 src/Admin/AdminConsole/Views/Providers/CreateMsp.cshtml create mode 100644 src/Admin/AdminConsole/Views/Providers/CreateMultiOrganizationEnterprise.cshtml create mode 100644 src/Admin/AdminConsole/Views/Providers/CreateReseller.cshtml create mode 100644 src/Admin/Enums/HtmlHelperExtensions.cs create mode 100644 test/Admin.Test/AdminConsole/Controllers/ProvidersControllerTests.cs diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/CreateProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/CreateProviderCommand.cs index 09157d72c5..d192073d4d 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/CreateProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/CreateProviderCommand.cs @@ -40,6 +40,32 @@ public class CreateProviderCommand : ICreateProviderCommand } public async Task CreateMspAsync(Provider provider, string ownerEmail, int teamsMinimumSeats, int enterpriseMinimumSeats) + { + var providerPlans = new List + { + CreateProviderPlan(provider.Id, PlanType.TeamsMonthly, teamsMinimumSeats), + CreateProviderPlan(provider.Id, PlanType.EnterpriseMonthly, enterpriseMinimumSeats) + }; + + await CreateProviderAsync(provider, ownerEmail, providerPlans); + } + + public async Task CreateResellerAsync(Provider provider) + { + await ProviderRepositoryCreateAsync(provider, ProviderStatusType.Created); + } + + public async Task CreateMultiOrganizationEnterpriseAsync(Provider provider, string ownerEmail, PlanType plan, int minimumSeats) + { + var providerPlans = new List + { + CreateProviderPlan(provider.Id, plan, minimumSeats) + }; + + await CreateProviderAsync(provider, ownerEmail, providerPlans); + } + + private async Task CreateProviderAsync(Provider provider, string ownerEmail, List providerPlans) { var owner = await _userRepository.GetByEmailAsync(ownerEmail); if (owner == null) @@ -66,12 +92,6 @@ public class CreateProviderCommand : ICreateProviderCommand if (isConsolidatedBillingEnabled) { - var providerPlans = new List - { - CreateProviderPlan(provider.Id, PlanType.TeamsMonthly, teamsMinimumSeats), - CreateProviderPlan(provider.Id, PlanType.EnterpriseMonthly, enterpriseMinimumSeats) - }; - foreach (var providerPlan in providerPlans) { await _providerPlanRepository.CreateAsync(providerPlan); @@ -82,11 +102,6 @@ public class CreateProviderCommand : ICreateProviderCommand await _providerService.SendProviderSetupInviteEmailAsync(provider, owner.Email); } - public async Task CreateResellerAsync(Provider provider) - { - await ProviderRepositoryCreateAsync(provider, ProviderStatusType.Created); - } - private async Task ProviderRepositoryCreateAsync(Provider provider, ProviderStatusType status) { provider.Status = status; diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/CreateProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/CreateProviderCommandTests.cs index 787d5a17b3..e354e44173 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/CreateProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/CreateProviderCommandTests.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -19,23 +20,30 @@ public class CreateProviderCommandTests [Theory, BitAutoData] public async Task CreateMspAsync_UserIdIsInvalid_Throws(Provider provider, SutProvider sutProvider) { + // Arrange provider.Type = ProviderType.Msp; + // Act var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.CreateMspAsync(provider, default, default, default)); + + // Assert Assert.Contains("Invalid owner.", exception.Message); } [Theory, BitAutoData] public async Task CreateMspAsync_Success(Provider provider, User user, SutProvider sutProvider) { + // Arrange provider.Type = ProviderType.Msp; var userRepository = sutProvider.GetDependency(); userRepository.GetByEmailAsync(user.Email).Returns(user); + // Act await sutProvider.Sut.CreateMspAsync(provider, user.Email, default, default); + // Assert await sutProvider.GetDependency().ReceivedWithAnyArgs().CreateAsync(default); await sutProvider.GetDependency().Received(1).SendProviderSetupInviteEmailAsync(provider, user.Email); } @@ -43,11 +51,52 @@ public class CreateProviderCommandTests [Theory, BitAutoData] public async Task CreateResellerAsync_Success(Provider provider, SutProvider sutProvider) { + // Arrange provider.Type = ProviderType.Reseller; + // Act await sutProvider.Sut.CreateResellerAsync(provider); + // Assert await sutProvider.GetDependency().ReceivedWithAnyArgs().CreateAsync(default); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().SendProviderSetupInviteEmailAsync(default, default); } + + [Theory, BitAutoData] + public async Task CreateMultiOrganizationEnterpriseAsync_Success( + Provider provider, + User user, + PlanType plan, + int minimumSeats, + SutProvider sutProvider) + { + // Arrange + provider.Type = ProviderType.MultiOrganizationEnterprise; + + var userRepository = sutProvider.GetDependency(); + userRepository.GetByEmailAsync(user.Email).Returns(user); + + // Act + await sutProvider.Sut.CreateMultiOrganizationEnterpriseAsync(provider, user.Email, plan, minimumSeats); + + // Assert + await sutProvider.GetDependency().ReceivedWithAnyArgs().CreateAsync(provider); + await sutProvider.GetDependency().Received(1).SendProviderSetupInviteEmailAsync(provider, user.Email); + } + + [Theory, BitAutoData] + public async Task CreateMultiOrganizationEnterpriseAsync_UserIdIsInvalid_Throws( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.Type = ProviderType.Msp; + + // Act + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateMultiOrganizationEnterpriseAsync(provider, default, default, default)); + + // Assert + Assert.Contains("Invalid owner.", exception.Message); + } } diff --git a/src/Admin/AdminConsole/Controllers/ProvidersController.cs b/src/Admin/AdminConsole/Controllers/ProvidersController.cs index 12e2c4d439..a7c49b214b 100644 --- a/src/Admin/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Admin/AdminConsole/Controllers/ProvidersController.cs @@ -107,9 +107,15 @@ public class ProvidersController : Controller }); } - public IActionResult Create(int teamsMinimumSeats, int enterpriseMinimumSeats, string ownerEmail = null) + public IActionResult Create() { - return View(new CreateProviderModel + return View(new CreateProviderModel()); + } + + [HttpGet("providers/create/msp")] + public IActionResult CreateMsp(int teamsMinimumSeats, int enterpriseMinimumSeats, string ownerEmail = null) + { + return View(new CreateMspProviderModel { OwnerEmail = ownerEmail, TeamsMonthlySeatMinimum = teamsMinimumSeats, @@ -117,10 +123,50 @@ public class ProvidersController : Controller }); } + [HttpGet("providers/create/reseller")] + public IActionResult CreateReseller() + { + return View(new CreateResellerProviderModel()); + } + + [HttpGet("providers/create/multi-organization-enterprise")] + public IActionResult CreateMultiOrganizationEnterprise(int enterpriseMinimumSeats, string ownerEmail = null) + { + if (!_featureService.IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises)) + { + return RedirectToAction("Create"); + } + + return View(new CreateMultiOrganizationEnterpriseProviderModel + { + OwnerEmail = ownerEmail, + EnterpriseSeatMinimum = enterpriseMinimumSeats + }); + } + [HttpPost] [ValidateAntiForgeryToken] [RequirePermission(Permission.Provider_Create)] - public async Task Create(CreateProviderModel model) + public IActionResult Create(CreateProviderModel model) + { + if (!ModelState.IsValid) + { + return View(model); + } + + return model.Type switch + { + ProviderType.Msp => RedirectToAction("CreateMsp"), + ProviderType.Reseller => RedirectToAction("CreateReseller"), + ProviderType.MultiOrganizationEnterprise => RedirectToAction("CreateMultiOrganizationEnterprise"), + _ => View(model) + }; + } + + [HttpPost("providers/create/msp")] + [ValidateAntiForgeryToken] + [RequirePermission(Permission.Provider_Create)] + public async Task CreateMsp(CreateMspProviderModel model) { if (!ModelState.IsValid) { @@ -128,19 +174,51 @@ public class ProvidersController : Controller } var provider = model.ToProvider(); - switch (provider.Type) + + await _createProviderCommand.CreateMspAsync( + provider, + model.OwnerEmail, + model.TeamsMonthlySeatMinimum, + model.EnterpriseMonthlySeatMinimum); + + return RedirectToAction("Edit", new { id = provider.Id }); + } + + [HttpPost("providers/create/reseller")] + [ValidateAntiForgeryToken] + [RequirePermission(Permission.Provider_Create)] + public async Task CreateReseller(CreateResellerProviderModel model) + { + if (!ModelState.IsValid) { - case ProviderType.Msp: - await _createProviderCommand.CreateMspAsync( - provider, - model.OwnerEmail, - model.TeamsMonthlySeatMinimum, - model.EnterpriseMonthlySeatMinimum); - break; - case ProviderType.Reseller: - await _createProviderCommand.CreateResellerAsync(provider); - break; + return View(model); } + var provider = model.ToProvider(); + await _createProviderCommand.CreateResellerAsync(provider); + + return RedirectToAction("Edit", new { id = provider.Id }); + } + + [HttpPost("providers/create/multi-organization-enterprise")] + [ValidateAntiForgeryToken] + [RequirePermission(Permission.Provider_Create)] + public async Task CreateMultiOrganizationEnterprise(CreateMultiOrganizationEnterpriseProviderModel model) + { + if (!ModelState.IsValid) + { + return View(model); + } + var provider = model.ToProvider(); + + if (!_featureService.IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises)) + { + return RedirectToAction("Create"); + } + await _createProviderCommand.CreateMultiOrganizationEnterpriseAsync( + provider, + model.OwnerEmail, + model.Plan.Value, + model.EnterpriseSeatMinimum); return RedirectToAction("Edit", new { id = provider.Id }); } diff --git a/src/Admin/AdminConsole/Models/CreateMspProviderModel.cs b/src/Admin/AdminConsole/Models/CreateMspProviderModel.cs new file mode 100644 index 0000000000..f48cf21767 --- /dev/null +++ b/src/Admin/AdminConsole/Models/CreateMspProviderModel.cs @@ -0,0 +1,45 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.SharedWeb.Utilities; + +namespace Bit.Admin.AdminConsole.Models; + +public class CreateMspProviderModel : IValidatableObject +{ + [Display(Name = "Owner Email")] + public string OwnerEmail { get; set; } + + [Display(Name = "Teams (Monthly) Seat Minimum")] + public int TeamsMonthlySeatMinimum { get; set; } + + [Display(Name = "Enterprise (Monthly) Seat Minimum")] + public int EnterpriseMonthlySeatMinimum { get; set; } + + public virtual Provider ToProvider() + { + return new Provider + { + Type = ProviderType.Msp + }; + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (string.IsNullOrWhiteSpace(OwnerEmail)) + { + var ownerEmailDisplayName = nameof(OwnerEmail).GetDisplayAttribute()?.GetName() ?? nameof(OwnerEmail); + yield return new ValidationResult($"The {ownerEmailDisplayName} field is required."); + } + if (TeamsMonthlySeatMinimum < 0) + { + var teamsMinimumSeatsDisplayName = nameof(TeamsMonthlySeatMinimum).GetDisplayAttribute()?.GetName() ?? nameof(TeamsMonthlySeatMinimum); + yield return new ValidationResult($"The {teamsMinimumSeatsDisplayName} field can not be negative."); + } + if (EnterpriseMonthlySeatMinimum < 0) + { + var enterpriseMinimumSeatsDisplayName = nameof(EnterpriseMonthlySeatMinimum).GetDisplayAttribute()?.GetName() ?? nameof(EnterpriseMonthlySeatMinimum); + yield return new ValidationResult($"The {enterpriseMinimumSeatsDisplayName} field can not be negative."); + } + } +} diff --git a/src/Admin/AdminConsole/Models/CreateMultiOrganizationEnterpriseProviderModel.cs b/src/Admin/AdminConsole/Models/CreateMultiOrganizationEnterpriseProviderModel.cs new file mode 100644 index 0000000000..ef7210a9ef --- /dev/null +++ b/src/Admin/AdminConsole/Models/CreateMultiOrganizationEnterpriseProviderModel.cs @@ -0,0 +1,47 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Enums; +using Bit.SharedWeb.Utilities; + +namespace Bit.Admin.AdminConsole.Models; + +public class CreateMultiOrganizationEnterpriseProviderModel : IValidatableObject +{ + [Display(Name = "Owner Email")] + public string OwnerEmail { get; set; } + + [Display(Name = "Enterprise Seat Minimum")] + public int EnterpriseSeatMinimum { get; set; } + + [Display(Name = "Plan")] + [Required] + public PlanType? Plan { get; set; } + + public virtual Provider ToProvider() + { + return new Provider + { + Type = ProviderType.MultiOrganizationEnterprise + }; + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (string.IsNullOrWhiteSpace(OwnerEmail)) + { + var ownerEmailDisplayName = nameof(OwnerEmail).GetDisplayAttribute()?.GetName() ?? nameof(OwnerEmail); + yield return new ValidationResult($"The {ownerEmailDisplayName} field is required."); + } + if (EnterpriseSeatMinimum < 0) + { + var enterpriseSeatMinimumDisplayName = nameof(EnterpriseSeatMinimum).GetDisplayAttribute()?.GetName() ?? nameof(EnterpriseSeatMinimum); + yield return new ValidationResult($"The {enterpriseSeatMinimumDisplayName} field can not be negative."); + } + if (Plan != PlanType.EnterpriseAnnually && Plan != PlanType.EnterpriseMonthly) + { + var planDisplayName = nameof(Plan).GetDisplayAttribute()?.GetName() ?? nameof(Plan); + yield return new ValidationResult($"The {planDisplayName} field must be set to Enterprise Annually or Enterprise Monthly."); + } + } +} diff --git a/src/Admin/AdminConsole/Models/CreateProviderModel.cs b/src/Admin/AdminConsole/Models/CreateProviderModel.cs index 07bb1b6e4c..da73787a9c 100644 --- a/src/Admin/AdminConsole/Models/CreateProviderModel.cs +++ b/src/Admin/AdminConsole/Models/CreateProviderModel.cs @@ -1,84 +1,8 @@ -using System.ComponentModel.DataAnnotations; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.SharedWeb.Utilities; +using Bit.Core.AdminConsole.Enums.Provider; namespace Bit.Admin.AdminConsole.Models; -public class CreateProviderModel : IValidatableObject +public class CreateProviderModel { - public CreateProviderModel() { } - - [Display(Name = "Provider Type")] public ProviderType Type { get; set; } - - [Display(Name = "Owner Email")] - public string OwnerEmail { get; set; } - - [Display(Name = "Name")] - public string Name { get; set; } - - [Display(Name = "Business Name")] - public string BusinessName { get; set; } - - [Display(Name = "Primary Billing Email")] - public string BillingEmail { get; set; } - - [Display(Name = "Teams (Monthly) Seat Minimum")] - public int TeamsMonthlySeatMinimum { get; set; } - - [Display(Name = "Enterprise (Monthly) Seat Minimum")] - public int EnterpriseMonthlySeatMinimum { get; set; } - - public virtual Provider ToProvider() - { - return new Provider() - { - Type = Type, - Name = Name, - BusinessName = BusinessName, - BillingEmail = BillingEmail?.ToLowerInvariant().Trim() - }; - } - - public IEnumerable Validate(ValidationContext validationContext) - { - switch (Type) - { - case ProviderType.Msp: - if (string.IsNullOrWhiteSpace(OwnerEmail)) - { - var ownerEmailDisplayName = nameof(OwnerEmail).GetDisplayAttribute()?.GetName() ?? nameof(OwnerEmail); - yield return new ValidationResult($"The {ownerEmailDisplayName} field is required."); - } - if (TeamsMonthlySeatMinimum < 0) - { - var teamsMinimumSeatsDisplayName = nameof(TeamsMonthlySeatMinimum).GetDisplayAttribute()?.GetName() ?? nameof(TeamsMonthlySeatMinimum); - yield return new ValidationResult($"The {teamsMinimumSeatsDisplayName} field can not be negative."); - } - if (EnterpriseMonthlySeatMinimum < 0) - { - var enterpriseMinimumSeatsDisplayName = nameof(EnterpriseMonthlySeatMinimum).GetDisplayAttribute()?.GetName() ?? nameof(EnterpriseMonthlySeatMinimum); - yield return new ValidationResult($"The {enterpriseMinimumSeatsDisplayName} field can not be negative."); - } - break; - case ProviderType.Reseller: - if (string.IsNullOrWhiteSpace(Name)) - { - var nameDisplayName = nameof(Name).GetDisplayAttribute()?.GetName() ?? nameof(Name); - yield return new ValidationResult($"The {nameDisplayName} field is required."); - } - if (string.IsNullOrWhiteSpace(BusinessName)) - { - var businessNameDisplayName = nameof(BusinessName).GetDisplayAttribute()?.GetName() ?? nameof(BusinessName); - yield return new ValidationResult($"The {businessNameDisplayName} field is required."); - } - if (string.IsNullOrWhiteSpace(BillingEmail)) - { - var billingEmailDisplayName = nameof(BillingEmail).GetDisplayAttribute()?.GetName() ?? nameof(BillingEmail); - yield return new ValidationResult($"The {billingEmailDisplayName} field is required."); - } - break; - } - } } diff --git a/src/Admin/AdminConsole/Models/CreateResellerProviderModel.cs b/src/Admin/AdminConsole/Models/CreateResellerProviderModel.cs new file mode 100644 index 0000000000..958faf3f85 --- /dev/null +++ b/src/Admin/AdminConsole/Models/CreateResellerProviderModel.cs @@ -0,0 +1,48 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.SharedWeb.Utilities; + +namespace Bit.Admin.AdminConsole.Models; + +public class CreateResellerProviderModel : IValidatableObject +{ + [Display(Name = "Name")] + public string Name { get; set; } + + [Display(Name = "Business Name")] + public string BusinessName { get; set; } + + [Display(Name = "Primary Billing Email")] + public string BillingEmail { get; set; } + + public virtual Provider ToProvider() + { + return new Provider + { + Name = Name, + BusinessName = BusinessName, + BillingEmail = BillingEmail?.ToLowerInvariant().Trim(), + Type = ProviderType.Reseller + }; + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (string.IsNullOrWhiteSpace(Name)) + { + var nameDisplayName = nameof(Name).GetDisplayAttribute()?.GetName() ?? nameof(Name); + yield return new ValidationResult($"The {nameDisplayName} field is required."); + } + if (string.IsNullOrWhiteSpace(BusinessName)) + { + var businessNameDisplayName = nameof(BusinessName).GetDisplayAttribute()?.GetName() ?? nameof(BusinessName); + yield return new ValidationResult($"The {businessNameDisplayName} field is required."); + } + if (string.IsNullOrWhiteSpace(BillingEmail)) + { + var billingEmailDisplayName = nameof(BillingEmail).GetDisplayAttribute()?.GetName() ?? nameof(BillingEmail); + yield return new ValidationResult($"The {billingEmailDisplayName} field is required."); + } + } +} diff --git a/src/Admin/AdminConsole/Views/Providers/Create.cshtml b/src/Admin/AdminConsole/Views/Providers/Create.cshtml index 41855895e1..8f43a4f85e 100644 --- a/src/Admin/AdminConsole/Views/Providers/Create.cshtml +++ b/src/Admin/AdminConsole/Views/Providers/Create.cshtml @@ -1,80 +1,48 @@ @using Bit.SharedWeb.Utilities @using Bit.Core.AdminConsole.Enums.Provider @using Bit.Core + @model CreateProviderModel + @inject Bit.Core.Services.IFeatureService FeatureService + @{ ViewData["Title"] = "Create Provider"; -} -@section Scripts { - + var providerTypes = Enum.GetValues() + .OrderBy(x => x.GetDisplayAttribute().Order) + .ToList(); + + if (!FeatureService.IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises)) + { + providerTypes.Remove(ProviderType.MultiOrganizationEnterprise); + } }

Create Provider

- -
+
-
- @foreach(ProviderType providerType in Enum.GetValues(typeof(ProviderType))) + @foreach (var providerType in providerTypes) { var providerTypeValue = (int)providerType; -
- @Html.RadioButtonFor(m => m.Type, providerType, new { id = $"providerType-{providerTypeValue}", @class = "form-check-input", onclick=$"toggleProviderTypeInfo({providerTypeValue})" }) - @Html.LabelFor(m => m.Type, providerType.GetDisplayAttribute()?.GetName(), new { @class = "form-check-label align-middle", @for = $"providerType-{providerTypeValue}" }) -
- @Html.LabelFor(m => m.Type, providerType.GetDisplayAttribute()?.GetDescription(), new { @class = "form-check-label small text-muted ml-3 align-top", @for = $"providerType-{providerTypeValue}" }) -
- } -
- -
-

MSP Info

-
- - -
- @if (FeatureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) - { -
-
-
- - +
+
+
+
+ @Html.RadioButtonFor(m => m.Type, providerType, new { id = $"providerType-{providerTypeValue}", @class = "form-check-input" }) + @Html.LabelFor(m => m.Type, providerType.GetDisplayAttribute()?.GetName(), new { @class = "form-check-label align-middle", @for = $"providerType-{providerTypeValue}" }) +
-
-
- - +
+
+ @Html.LabelFor(m => m.Type, providerType.GetDisplayAttribute()?.GetDescription(), new { @class = "form-check-label small text-muted align-top", @for = $"providerType-{providerTypeValue}" })
}
- -
-

Reseller Info

-
- - -
-
- - -
-
- - -
-
- - + diff --git a/src/Admin/AdminConsole/Views/Providers/CreateMsp.cshtml b/src/Admin/AdminConsole/Views/Providers/CreateMsp.cshtml new file mode 100644 index 0000000000..dde62b58a9 --- /dev/null +++ b/src/Admin/AdminConsole/Views/Providers/CreateMsp.cshtml @@ -0,0 +1,39 @@ +@using Bit.Core.AdminConsole.Enums.Provider +@using Bit.Core + +@model CreateMspProviderModel + +@inject Bit.Core.Services.IFeatureService FeatureService + +@{ + ViewData["Title"] = "Create Managed Service Provider"; +} + +

Create Managed Service Provider

+
+
+
+
+ + +
+ @if (FeatureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { +
+
+
+ + +
+
+
+
+ + +
+
+
+ } + +
+
diff --git a/src/Admin/AdminConsole/Views/Providers/CreateMultiOrganizationEnterprise.cshtml b/src/Admin/AdminConsole/Views/Providers/CreateMultiOrganizationEnterprise.cshtml new file mode 100644 index 0000000000..997fa32ef6 --- /dev/null +++ b/src/Admin/AdminConsole/Views/Providers/CreateMultiOrganizationEnterprise.cshtml @@ -0,0 +1,43 @@ +@using Bit.Core.Billing.Enums +@using Microsoft.AspNetCore.Mvc.TagHelpers + +@model CreateMultiOrganizationEnterpriseProviderModel + +@{ + ViewData["Title"] = "Create Multi-organization Enterprise Provider"; +} + +

Create Multi-organization Enterprise Provider

+
+
+
+
+ + +
+
+
+
+ @{ + var multiOrgPlans = new List + { + PlanType.EnterpriseAnnually, + PlanType.EnterpriseMonthly + }; + } + + +
+
+
+
+ + +
+
+
+ +
+
diff --git a/src/Admin/AdminConsole/Views/Providers/CreateReseller.cshtml b/src/Admin/AdminConsole/Views/Providers/CreateReseller.cshtml new file mode 100644 index 0000000000..320ff7a4bd --- /dev/null +++ b/src/Admin/AdminConsole/Views/Providers/CreateReseller.cshtml @@ -0,0 +1,25 @@ +@model CreateResellerProviderModel + +@{ + ViewData["Title"] = "Create Reseller Provider"; +} + +

Create Reseller Provider

+
+
+
+
+ + +
+
+ + +
+
+ + +
+ +
+
diff --git a/src/Admin/Enums/HtmlHelperExtensions.cs b/src/Admin/Enums/HtmlHelperExtensions.cs new file mode 100644 index 0000000000..a5fb893030 --- /dev/null +++ b/src/Admin/Enums/HtmlHelperExtensions.cs @@ -0,0 +1,19 @@ + +using Bit.SharedWeb.Utilities; + +// ReSharper disable once CheckNamespace +namespace Microsoft.AspNetCore.Mvc.Rendering; + +public static class HtmlHelper +{ + public static IEnumerable GetEnumSelectList(this IHtmlHelper htmlHelper, IEnumerable values) + where T : Enum + { + return values.Select(v => new SelectListItem + { + Text = v.GetDisplayAttribute().Name, + Value = v.ToString() + }); + } + +} diff --git a/src/Core/AdminConsole/Enums/Provider/ProviderType.cs b/src/Core/AdminConsole/Enums/Provider/ProviderType.cs index a159fe2b6b..50c344ec95 100644 --- a/src/Core/AdminConsole/Enums/Provider/ProviderType.cs +++ b/src/Core/AdminConsole/Enums/Provider/ProviderType.cs @@ -4,8 +4,10 @@ namespace Bit.Core.AdminConsole.Enums.Provider; public enum ProviderType : byte { - [Display(ShortName = "MSP", Name = "Managed Service Provider", Description = "Access to clients organization")] + [Display(ShortName = "MSP", Name = "Managed Service Provider", Description = "Access to clients organization", Order = 0)] Msp = 0, - [Display(ShortName = "Reseller", Name = "Reseller", Description = "Access to clients billing")] + [Display(ShortName = "Reseller", Name = "Reseller", Description = "Access to clients billing", Order = 1000)] Reseller = 1, + [Display(ShortName = "MOE", Name = "Multi-organization Enterprise", Description = "Access to multiple organizations", Order = 1)] + MultiOrganizationEnterprise = 2, } diff --git a/src/Core/AdminConsole/Providers/Interfaces/ICreateProviderCommand.cs b/src/Core/AdminConsole/Providers/Interfaces/ICreateProviderCommand.cs index 800ec14055..bea3c08a85 100644 --- a/src/Core/AdminConsole/Providers/Interfaces/ICreateProviderCommand.cs +++ b/src/Core/AdminConsole/Providers/Interfaces/ICreateProviderCommand.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Enums; namespace Bit.Core.AdminConsole.Providers.Interfaces; @@ -6,4 +7,5 @@ public interface ICreateProviderCommand { Task CreateMspAsync(Provider provider, string ownerEmail, int teamsMinimumSeats, int enterpriseMinimumSeats); Task CreateResellerAsync(Provider provider); + Task CreateMultiOrganizationEnterpriseAsync(Provider provider, string ownerEmail, PlanType plan, int minimumSeats); } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index f193f7995a..b22e2cf91f 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -146,6 +146,7 @@ public static class FeatureFlagKeys public const string RemoveServerVersionHeader = "remove-server-version-header"; public const string AccessIntelligence = "pm-13227-access-intelligence"; public const string VerifiedSsoDomainEndpoint = "pm-12337-refactor-sso-details-endpoint"; + public const string PM12275_MultiOrganizationEnterprises = "pm-12275-multi-organization-enterprises"; public const string Pm13322AddPolicyDefinitions = "pm-13322-add-policy-definitions"; public const string LimitCollectionCreationDeletionSplit = "pm-10863-limit-collection-creation-deletion-split"; public const string GeneratorToolsModernization = "generator-tools-modernization"; diff --git a/test/Admin.Test/AdminConsole/Controllers/ProvidersControllerTests.cs b/test/Admin.Test/AdminConsole/Controllers/ProvidersControllerTests.cs new file mode 100644 index 0000000000..be9883ba07 --- /dev/null +++ b/test/Admin.Test/AdminConsole/Controllers/ProvidersControllerTests.cs @@ -0,0 +1,251 @@ +using Bit.Admin.AdminConsole.Controllers; +using Bit.Admin.AdminConsole.Models; +using Bit.Core; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Providers.Interfaces; +using Bit.Core.Billing.Enums; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Mvc; +using NSubstitute; +using NSubstitute.ReceivedExtensions; + +namespace Admin.Test.AdminConsole.Controllers; + +[ControllerCustomize(typeof(ProvidersController))] +[SutProviderCustomize] +public class ProvidersControllerTests +{ + #region CreateMspAsync + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task CreateMspAsync_WithValidModel_CreatesProvider( + CreateMspProviderModel model, + SutProvider sutProvider) + { + // Arrange + + // Act + var actual = await sutProvider.Sut.CreateMsp(model); + + // Assert + Assert.NotNull(actual); + await sutProvider.GetDependency() + .Received(Quantity.Exactly(1)) + .CreateMspAsync( + Arg.Is(x => x.Type == ProviderType.Msp), + model.OwnerEmail, + model.TeamsMonthlySeatMinimum, + model.EnterpriseMonthlySeatMinimum); + } + + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task CreateMspAsync_RedirectsToExpectedPage_AfterCreatingProvider( + CreateMspProviderModel model, + Guid expectedProviderId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .When(x => + x.CreateMspAsync( + Arg.Is(y => y.Type == ProviderType.Msp), + model.OwnerEmail, + model.TeamsMonthlySeatMinimum, + model.EnterpriseMonthlySeatMinimum)) + .Do(callInfo => + { + var providerArgument = callInfo.ArgAt(0); + providerArgument.Id = expectedProviderId; + }); + + // Act + var actual = await sutProvider.Sut.CreateMsp(model); + + // Assert + Assert.NotNull(actual); + Assert.IsType(actual); + var actualResult = (RedirectToActionResult)actual; + Assert.Equal("Edit", actualResult.ActionName); + Assert.Null(actualResult.ControllerName); + Assert.Equal(expectedProviderId, actualResult.RouteValues["Id"]); + } + #endregion + + #region CreateMultiOrganizationEnterpriseAsync + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task CreateMultiOrganizationEnterpriseAsync_WithValidModel_CreatesProvider( + CreateMultiOrganizationEnterpriseProviderModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises) + .Returns(true); + + // Act + var actual = await sutProvider.Sut.CreateMultiOrganizationEnterprise(model); + + // Assert + Assert.NotNull(actual); + await sutProvider.GetDependency() + .Received(Quantity.Exactly(1)) + .CreateMultiOrganizationEnterpriseAsync( + Arg.Is(x => x.Type == ProviderType.MultiOrganizationEnterprise), + model.OwnerEmail, + Arg.Is(y => y == model.Plan), + model.EnterpriseSeatMinimum); + sutProvider.GetDependency() + .Received(Quantity.Exactly(1)) + .IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises); + } + + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task CreateMultiOrganizationEnterpriseAsync_RedirectsToExpectedPage_AfterCreatingProvider( + CreateMultiOrganizationEnterpriseProviderModel model, + Guid expectedProviderId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .When(x => + x.CreateMultiOrganizationEnterpriseAsync( + Arg.Is(y => y.Type == ProviderType.MultiOrganizationEnterprise), + model.OwnerEmail, + Arg.Is(y => y == model.Plan), + model.EnterpriseSeatMinimum)) + .Do(callInfo => + { + var providerArgument = callInfo.ArgAt(0); + providerArgument.Id = expectedProviderId; + }); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises) + .Returns(true); + + // Act + var actual = await sutProvider.Sut.CreateMultiOrganizationEnterprise(model); + + // Assert + Assert.NotNull(actual); + Assert.IsType(actual); + var actualResult = (RedirectToActionResult)actual; + Assert.Equal("Edit", actualResult.ActionName); + Assert.Null(actualResult.ControllerName); + Assert.Equal(expectedProviderId, actualResult.RouteValues["Id"]); + } + + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task CreateMultiOrganizationEnterpriseAsync_ChecksFeatureFlag( + CreateMultiOrganizationEnterpriseProviderModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises) + .Returns(true); + + // Act + await sutProvider.Sut.CreateMultiOrganizationEnterprise(model); + + // Assert + sutProvider.GetDependency() + .Received(Quantity.Exactly(1)) + .IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises); + } + + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task CreateMultiOrganizationEnterpriseAsync_RedirectsToProviderTypeSelectionPage_WhenFeatureFlagIsDisabled( + CreateMultiOrganizationEnterpriseProviderModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises) + .Returns(false); + + // Act + var actual = await sutProvider.Sut.CreateMultiOrganizationEnterprise(model); + + // Assert + sutProvider.GetDependency() + .Received(Quantity.Exactly(1)) + .IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises); + + Assert.IsType(actual); + var actualResult = (RedirectToActionResult)actual; + Assert.Equal("Create", actualResult.ActionName); + Assert.Null(actualResult.ControllerName); + } + #endregion + + #region CreateResellerAsync + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task CreateResellerAsync_WithValidModel_CreatesProvider( + CreateResellerProviderModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PM12275_MultiOrganizationEnterprises) + .Returns(true); + + // Act + var actual = await sutProvider.Sut.CreateReseller(model); + + // Assert + Assert.NotNull(actual); + await sutProvider.GetDependency() + .Received(Quantity.Exactly(1)) + .CreateResellerAsync( + Arg.Is(x => x.Type == ProviderType.Reseller)); + } + + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task CreateResellerAsync_RedirectsToExpectedPage_AfterCreatingProvider( + CreateResellerProviderModel model, + Guid expectedProviderId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .When(x => + x.CreateResellerAsync( + Arg.Is(y => y.Type == ProviderType.Reseller))) + .Do(callInfo => + { + var providerArgument = callInfo.ArgAt(0); + providerArgument.Id = expectedProviderId; + }); + + // Act + var actual = await sutProvider.Sut.CreateReseller(model); + + // Assert + Assert.NotNull(actual); + Assert.IsType(actual); + var actualResult = (RedirectToActionResult)actual; + Assert.Equal("Edit", actualResult.ActionName); + Assert.Null(actualResult.ControllerName); + Assert.Equal(expectedProviderId, actualResult.RouteValues["Id"]); + } + #endregion +} From e6245bbece2671b3c2578633268ffde63ab1f1a8 Mon Sep 17 00:00:00 2001 From: Jared Snider <116684653+JaredSnider-Bitwarden@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:06:24 -0400 Subject: [PATCH 3/5] Auth/PM-12613 - Registration with Email Verification - Provider Invite Flow (#4917) * PM-12613 - Add RegisterUserViaProviderInviteToken flow (needs manual, unit, and integration tests) * PM-12613 - RegisterUserCommandTests - test register via provider inv * PM-12613 - AccountsControllerTests.cs - Add integration test for provider * PM-12613 - Remove comment * PM-12613 - Add temp logging to help debug integration test failure in pipeline * PM-12613 - WebApplicationFactoryBase.cs - add ConfigureServices * PM-12613 - AccountsControllerTests.cs - refactor test to sidestep encryption * PM-12613 - Per PR feedback, refactor AccountsController.cs and move token type checking into request model. * PM-12613 - Remove debug writelines * PM-12613 - Add RegisterFinishRequestModelTests --- .../Accounts/RegisterFinishRequestModel.cs | 38 ++++ .../Registration/IRegisterUserCommand.cs | 12 ++ .../Implementations/RegisterUserCommand.cs | 31 ++++ .../Controllers/AccountsController.cs | 69 +++---- .../RegisterFinishRequestModelTests.cs | 173 ++++++++++++++++++ .../Registration/RegisterUserCommandTests.cs | 165 ++++++++++++++++- .../Controllers/AccountsControllerTests.cs | 80 +++++++- .../Factories/WebApplicationFactoryBase.cs | 10 + 8 files changed, 535 insertions(+), 43 deletions(-) create mode 100644 test/Core.Test/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModelTests.cs diff --git a/src/Core/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModel.cs b/src/Core/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModel.cs index 9036651fd6..0ac7dbbcb4 100644 --- a/src/Core/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModel.cs +++ b/src/Core/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModel.cs @@ -6,6 +6,14 @@ using Bit.Core.Utilities; namespace Bit.Core.Auth.Models.Api.Request.Accounts; using System.ComponentModel.DataAnnotations; +public enum RegisterFinishTokenType : byte +{ + EmailVerification = 1, + OrganizationInvite = 2, + OrgSponsoredFreeFamilyPlan = 3, + EmergencyAccessInvite = 4, + ProviderInvite = 5, +} public class RegisterFinishRequestModel : IValidatableObject { @@ -36,6 +44,10 @@ public class RegisterFinishRequestModel : IValidatableObject public string? AcceptEmergencyAccessInviteToken { get; set; } public Guid? AcceptEmergencyAccessId { get; set; } + public string? ProviderInviteToken { get; set; } + + public Guid? ProviderUserId { get; set; } + public User ToUser() { var user = new User @@ -54,6 +66,32 @@ public class RegisterFinishRequestModel : IValidatableObject return user; } + public RegisterFinishTokenType GetTokenType() + { + if (!string.IsNullOrWhiteSpace(EmailVerificationToken)) + { + return RegisterFinishTokenType.EmailVerification; + } + if (!string.IsNullOrEmpty(OrgInviteToken) && OrganizationUserId.HasValue) + { + return RegisterFinishTokenType.OrganizationInvite; + } + if (!string.IsNullOrWhiteSpace(OrgSponsoredFreeFamilyPlanToken)) + { + return RegisterFinishTokenType.OrgSponsoredFreeFamilyPlan; + } + if (!string.IsNullOrWhiteSpace(AcceptEmergencyAccessInviteToken) && AcceptEmergencyAccessId.HasValue) + { + return RegisterFinishTokenType.EmergencyAccessInvite; + } + if (!string.IsNullOrWhiteSpace(ProviderInviteToken) && ProviderUserId.HasValue) + { + return RegisterFinishTokenType.ProviderInvite; + } + + throw new InvalidOperationException("Invalid token type."); + } + public IEnumerable Validate(ValidationContext validationContext) { diff --git a/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs b/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs index d507cda4ed..f61cce895a 100644 --- a/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs @@ -61,4 +61,16 @@ public interface IRegisterUserCommand public Task RegisterUserViaAcceptEmergencyAccessInviteToken(User user, string masterPasswordHash, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId); + /// + /// Creates a new user with a given master password hash, sends a welcome email, and raises the signup reference event. + /// If a valid token is provided, the user will be created with their email verified. + /// If the token is invalid or expired, an error will be thrown. + /// + /// The to create + /// The hashed master password the user entered + /// The provider invite token sent to the user via email + /// The provider user id which is used to validate the invite token + /// + public Task RegisterUserViaProviderInviteToken(User user, string masterPasswordHash, string providerInviteToken, Guid providerUserId); + } diff --git a/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs b/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs index 3bbdaaf0af..8174d7d364 100644 --- a/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs @@ -32,6 +32,7 @@ public class RegisterUserCommand : IRegisterUserCommand private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; private readonly IDataProtectorTokenFactory _registrationEmailVerificationTokenDataFactory; private readonly IDataProtector _organizationServiceDataProtector; + private readonly IDataProtector _providerServiceDataProtector; private readonly ICurrentContext _currentContext; @@ -75,6 +76,8 @@ public class RegisterUserCommand : IRegisterUserCommand _validateRedemptionTokenCommand = validateRedemptionTokenCommand; _emergencyAccessInviteTokenDataFactory = emergencyAccessInviteTokenDataFactory; + + _providerServiceDataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); } @@ -303,6 +306,25 @@ public class RegisterUserCommand : IRegisterUserCommand return result; } + public async Task RegisterUserViaProviderInviteToken(User user, string masterPasswordHash, + string providerInviteToken, Guid providerUserId) + { + ValidateOpenRegistrationAllowed(); + ValidateProviderInviteToken(providerInviteToken, providerUserId, user.Email); + + user.EmailVerified = true; + user.ApiKey = CoreHelpers.SecureRandomString(30); // API key can't be null. + + var result = await _userService.CreateUserAsync(user, masterPasswordHash); + if (result == IdentityResult.Success) + { + await _mailService.SendWelcomeEmailAsync(user); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user, _currentContext)); + } + + return result; + } + private void ValidateOpenRegistrationAllowed() { // We validate open registration on send of initial email and here b/c a user could technically start the @@ -333,6 +355,15 @@ public class RegisterUserCommand : IRegisterUserCommand } } + private void ValidateProviderInviteToken(string providerInviteToken, Guid providerUserId, string userEmail) + { + if (!CoreHelpers.TokenIsValid("ProviderUserInvite", _providerServiceDataProtector, providerInviteToken, userEmail, providerUserId, + _globalSettings.OrganizationInviteExpirationHours)) + { + throw new BadRequestException("Invalid provider invite token."); + } + } + private RegistrationEmailVerificationTokenable ValidateRegistrationEmailVerificationTokenable(string emailVerificationToken, string userEmail) { diff --git a/src/Identity/Controllers/AccountsController.cs b/src/Identity/Controllers/AccountsController.cs index 38316566c6..40c926bda0 100644 --- a/src/Identity/Controllers/AccountsController.cs +++ b/src/Identity/Controllers/AccountsController.cs @@ -1,4 +1,5 @@ -using Bit.Core; +using System.Diagnostics; +using Bit.Core; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Api.Request.Accounts; using Bit.Core.Auth.Models.Api.Response.Accounts; @@ -149,40 +150,44 @@ public class AccountsController : Controller IdentityResult identityResult = null; var delaysEnabled = !_featureService.IsEnabled(FeatureFlagKeys.EmailVerificationDisableTimingDelays); - if (!string.IsNullOrEmpty(model.OrgInviteToken) && model.OrganizationUserId.HasValue) + switch (model.GetTokenType()) { - identityResult = await _registerUserCommand.RegisterUserViaOrganizationInviteToken(user, model.MasterPasswordHash, - model.OrgInviteToken, model.OrganizationUserId); + case RegisterFinishTokenType.EmailVerification: + identityResult = + await _registerUserCommand.RegisterUserViaEmailVerificationToken(user, model.MasterPasswordHash, + model.EmailVerificationToken); - return await ProcessRegistrationResult(identityResult, user, delaysEnabled); + return await ProcessRegistrationResult(identityResult, user, delaysEnabled); + break; + case RegisterFinishTokenType.OrganizationInvite: + identityResult = await _registerUserCommand.RegisterUserViaOrganizationInviteToken(user, model.MasterPasswordHash, + model.OrgInviteToken, model.OrganizationUserId); + + return await ProcessRegistrationResult(identityResult, user, delaysEnabled); + break; + case RegisterFinishTokenType.OrgSponsoredFreeFamilyPlan: + identityResult = await _registerUserCommand.RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken(user, model.MasterPasswordHash, model.OrgSponsoredFreeFamilyPlanToken); + + return await ProcessRegistrationResult(identityResult, user, delaysEnabled); + break; + case RegisterFinishTokenType.EmergencyAccessInvite: + Debug.Assert(model.AcceptEmergencyAccessId.HasValue); + identityResult = await _registerUserCommand.RegisterUserViaAcceptEmergencyAccessInviteToken(user, model.MasterPasswordHash, + model.AcceptEmergencyAccessInviteToken, model.AcceptEmergencyAccessId.Value); + + return await ProcessRegistrationResult(identityResult, user, delaysEnabled); + break; + case RegisterFinishTokenType.ProviderInvite: + Debug.Assert(model.ProviderUserId.HasValue); + identityResult = await _registerUserCommand.RegisterUserViaProviderInviteToken(user, model.MasterPasswordHash, + model.ProviderInviteToken, model.ProviderUserId.Value); + + return await ProcessRegistrationResult(identityResult, user, delaysEnabled); + break; + + default: + throw new BadRequestException("Invalid registration finish request"); } - - if (!string.IsNullOrEmpty(model.OrgSponsoredFreeFamilyPlanToken)) - { - identityResult = await _registerUserCommand.RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken(user, model.MasterPasswordHash, model.OrgSponsoredFreeFamilyPlanToken); - - return await ProcessRegistrationResult(identityResult, user, delaysEnabled); - } - - if (!string.IsNullOrEmpty(model.AcceptEmergencyAccessInviteToken) && model.AcceptEmergencyAccessId.HasValue) - { - identityResult = await _registerUserCommand.RegisterUserViaAcceptEmergencyAccessInviteToken(user, model.MasterPasswordHash, - model.AcceptEmergencyAccessInviteToken, model.AcceptEmergencyAccessId.Value); - - return await ProcessRegistrationResult(identityResult, user, delaysEnabled); - } - - if (string.IsNullOrEmpty(model.EmailVerificationToken)) - { - throw new BadRequestException("Invalid registration finish request"); - } - - identityResult = - await _registerUserCommand.RegisterUserViaEmailVerificationToken(user, model.MasterPasswordHash, - model.EmailVerificationToken); - - return await ProcessRegistrationResult(identityResult, user, delaysEnabled); - } private async Task ProcessRegistrationResult(IdentityResult result, User user, bool delaysEnabled) diff --git a/test/Core.Test/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModelTests.cs b/test/Core.Test/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModelTests.cs new file mode 100644 index 0000000000..588ca878fc --- /dev/null +++ b/test/Core.Test/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModelTests.cs @@ -0,0 +1,173 @@ +using Bit.Core.Auth.Models.Api.Request.Accounts; +using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Core.Test.Auth.Models.Api.Request.Accounts; + +public class RegisterFinishRequestModelTests +{ + [Theory] + [BitAutoData] + public void GetTokenType_Returns_EmailVerification(string email, string masterPasswordHash, + string userSymmetricKey, KeysRequestModel userAsymmetricKeys, KdfType kdf, int kdfIterations, string emailVerificationToken) + { + // Arrange + var model = new RegisterFinishRequestModel + { + Email = email, + MasterPasswordHash = masterPasswordHash, + UserSymmetricKey = userSymmetricKey, + UserAsymmetricKeys = userAsymmetricKeys, + Kdf = kdf, + KdfIterations = kdfIterations, + EmailVerificationToken = emailVerificationToken + }; + + // Act + Assert.Equal(RegisterFinishTokenType.EmailVerification, model.GetTokenType()); + } + + [Theory] + [BitAutoData] + public void GetTokenType_Returns_OrganizationInvite(string email, string masterPasswordHash, + string userSymmetricKey, KeysRequestModel userAsymmetricKeys, KdfType kdf, int kdfIterations, string orgInviteToken, Guid organizationUserId) + { + // Arrange + var model = new RegisterFinishRequestModel + { + Email = email, + MasterPasswordHash = masterPasswordHash, + UserSymmetricKey = userSymmetricKey, + UserAsymmetricKeys = userAsymmetricKeys, + Kdf = kdf, + KdfIterations = kdfIterations, + OrgInviteToken = orgInviteToken, + OrganizationUserId = organizationUserId + }; + + // Act + Assert.Equal(RegisterFinishTokenType.OrganizationInvite, model.GetTokenType()); + } + + [Theory] + [BitAutoData] + public void GetTokenType_Returns_OrgSponsoredFreeFamilyPlan(string email, string masterPasswordHash, + string userSymmetricKey, KeysRequestModel userAsymmetricKeys, KdfType kdf, int kdfIterations, string orgSponsoredFreeFamilyPlanToken) + { + // Arrange + var model = new RegisterFinishRequestModel + { + Email = email, + MasterPasswordHash = masterPasswordHash, + UserSymmetricKey = userSymmetricKey, + UserAsymmetricKeys = userAsymmetricKeys, + Kdf = kdf, + KdfIterations = kdfIterations, + OrgSponsoredFreeFamilyPlanToken = orgSponsoredFreeFamilyPlanToken + }; + + // Act + Assert.Equal(RegisterFinishTokenType.OrgSponsoredFreeFamilyPlan, model.GetTokenType()); + } + + [Theory] + [BitAutoData] + public void GetTokenType_Returns_EmergencyAccessInvite(string email, string masterPasswordHash, + string userSymmetricKey, KeysRequestModel userAsymmetricKeys, KdfType kdf, int kdfIterations, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) + { + // Arrange + var model = new RegisterFinishRequestModel + { + Email = email, + MasterPasswordHash = masterPasswordHash, + UserSymmetricKey = userSymmetricKey, + UserAsymmetricKeys = userAsymmetricKeys, + Kdf = kdf, + KdfIterations = kdfIterations, + AcceptEmergencyAccessInviteToken = acceptEmergencyAccessInviteToken, + AcceptEmergencyAccessId = acceptEmergencyAccessId + }; + + // Act + Assert.Equal(RegisterFinishTokenType.EmergencyAccessInvite, model.GetTokenType()); + } + + [Theory] + [BitAutoData] + public void GetTokenType_Returns_ProviderInvite(string email, string masterPasswordHash, + string userSymmetricKey, KeysRequestModel userAsymmetricKeys, KdfType kdf, int kdfIterations, string providerInviteToken, Guid providerUserId) + { + // Arrange + var model = new RegisterFinishRequestModel + { + Email = email, + MasterPasswordHash = masterPasswordHash, + UserSymmetricKey = userSymmetricKey, + UserAsymmetricKeys = userAsymmetricKeys, + Kdf = kdf, + KdfIterations = kdfIterations, + ProviderInviteToken = providerInviteToken, + ProviderUserId = providerUserId + }; + + // Act + Assert.Equal(RegisterFinishTokenType.ProviderInvite, model.GetTokenType()); + } + + [Theory] + [BitAutoData] + public void GetTokenType_Returns_Invalid(string email, string masterPasswordHash, + string userSymmetricKey, KeysRequestModel userAsymmetricKeys, KdfType kdf, int kdfIterations) + { + // Arrange + var model = new RegisterFinishRequestModel + { + Email = email, + MasterPasswordHash = masterPasswordHash, + UserSymmetricKey = userSymmetricKey, + UserAsymmetricKeys = userAsymmetricKeys, + Kdf = kdf, + KdfIterations = kdfIterations + }; + + // Act + var result = Assert.Throws(() => model.GetTokenType()); + Assert.Equal("Invalid token type.", result.Message); + } + + [Theory] + [BitAutoData] + public void ToUser_Returns_User(string email, string masterPasswordHash, string masterPasswordHint, + string userSymmetricKey, KeysRequestModel userAsymmetricKeys, KdfType kdf, int kdfIterations, + int? kdfMemory, int? kdfParallelism) + { + // Arrange + var model = new RegisterFinishRequestModel + { + Email = email, + MasterPasswordHash = masterPasswordHash, + MasterPasswordHint = masterPasswordHint, + UserSymmetricKey = userSymmetricKey, + UserAsymmetricKeys = userAsymmetricKeys, + Kdf = kdf, + KdfIterations = kdfIterations, + KdfMemory = kdfMemory, + KdfParallelism = kdfParallelism + }; + + // Act + var result = model.ToUser(); + + // Assert + Assert.Equal(email, result.Email); + Assert.Equal(masterPasswordHint, result.MasterPasswordHint); + Assert.Equal(kdf, result.Kdf); + Assert.Equal(kdfIterations, result.KdfIterations); + Assert.Equal(kdfMemory, result.KdfMemory); + Assert.Equal(kdfParallelism, result.KdfParallelism); + Assert.Equal(userSymmetricKey, result.Key); + Assert.Equal(userAsymmetricKeys.PublicKey, result.PublicKey); + Assert.Equal(userAsymmetricKeys.EncryptedPrivateKey, result.PrivateKey); + } +} diff --git a/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs b/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs index e96e3553df..02ecb4ecd7 100644 --- a/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs +++ b/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs @@ -1,4 +1,5 @@ -using Bit.Core.AdminConsole.Entities; +using System.Text; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Entities; @@ -19,7 +20,9 @@ using Bit.Core.Tools.Services; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Identity; +using Microsoft.AspNetCore.WebUtilities; using NSubstitute; using Xunit; @@ -28,8 +31,10 @@ namespace Bit.Core.Test.Auth.UserFeatures.Registration; [SutProviderCustomize] public class RegisterUserCommandTests { - + // ----------------------------------------------------------------------------------------------- // RegisterUser tests + // ----------------------------------------------------------------------------------------------- + [Theory] [BitAutoData] public async Task RegisterUser_Succeeds(SutProvider sutProvider, User user) @@ -86,7 +91,10 @@ public class RegisterUserCommandTests .RaiseEventAsync(Arg.Any()); } + // ----------------------------------------------------------------------------------------------- // RegisterUserWithOrganizationInviteToken tests + // ----------------------------------------------------------------------------------------------- + // Simple happy path test [Theory] [BitAutoData] @@ -312,7 +320,10 @@ public class RegisterUserCommandTests Assert.Equal(expectedErrorMessage, exception.Message); } - // RegisterUserViaEmailVerificationToken + // ----------------------------------------------------------------------------------------------- + // RegisterUserViaEmailVerificationToken tests + // ----------------------------------------------------------------------------------------------- + [Theory] [BitAutoData] public async Task RegisterUserViaEmailVerificationToken_Succeeds(SutProvider sutProvider, User user, string masterPasswordHash, string emailVerificationToken, bool receiveMarketingMaterials) @@ -382,10 +393,9 @@ public class RegisterUserCommandTests } - - - // RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken - + // ----------------------------------------------------------------------------------------------- + // RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken tests + // ----------------------------------------------------------------------------------------------- [Theory] [BitAutoData] @@ -452,7 +462,9 @@ public class RegisterUserCommandTests Assert.Equal("Open registration has been disabled by the system administrator.", result.Message); } - // RegisterUserViaAcceptEmergencyAccessInviteToken + // ----------------------------------------------------------------------------------------------- + // RegisterUserViaAcceptEmergencyAccessInviteToken tests + // ----------------------------------------------------------------------------------------------- [Theory] [BitAutoData] @@ -495,8 +507,6 @@ public class RegisterUserCommandTests .RaiseEventAsync(Arg.Is(refEvent => refEvent.Type == ReferenceEventType.Signup)); } - - [Theory] [BitAutoData] public async Task RegisterUserViaAcceptEmergencyAccessInviteToken_InvalidToken_ThrowsBadRequestException(SutProvider sutProvider, User user, @@ -536,5 +546,140 @@ public class RegisterUserCommandTests Assert.Equal("Open registration has been disabled by the system administrator.", result.Message); } + // ----------------------------------------------------------------------------------------------- + // RegisterUserViaProviderInviteToken tests + // ----------------------------------------------------------------------------------------------- + + [Theory] + [BitAutoData] + public async Task RegisterUserViaProviderInviteToken_Succeeds(SutProvider sutProvider, + User user, string masterPasswordHash, Guid providerUserId) + { + // Arrange + // Start with plaintext + var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); + var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {user.Email} {nowMillis}"; + + // Get the byte array of the plaintext + var decryptedProviderInviteTokenByteArray = Encoding.UTF8.GetBytes(decryptedProviderInviteToken); + + // Base64 encode the byte array (this is passed to protector.protect(bytes)) + var base64EncodedProviderInvToken = WebEncoders.Base64UrlEncode(decryptedProviderInviteTokenByteArray); + + var mockDataProtector = Substitute.For(); + + // Given any byte array, just return the decryptedProviderInviteTokenByteArray (sidestepping any actual encryption) + mockDataProtector.Unprotect(Arg.Any()).Returns(decryptedProviderInviteTokenByteArray); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(mockDataProtector); + + sutProvider.GetDependency() + .OrganizationInviteExpirationHours.Returns(120); // 5 days + + sutProvider.GetDependency() + .CreateUserAsync(user, masterPasswordHash) + .Returns(IdentityResult.Success); + + // Using sutProvider in the parameters of the function means that the constructor has already run for the + // command so we have to recreate it in order for our mock overrides to be used. + sutProvider.Create(); + + // Act + var result = await sutProvider.Sut.RegisterUserViaProviderInviteToken(user, masterPasswordHash, base64EncodedProviderInvToken, providerUserId); + + // Assert + Assert.True(result.Succeeded); + + await sutProvider.GetDependency() + .Received(1) + .CreateUserAsync(Arg.Is(u => u.Name == user.Name && u.EmailVerified == true && u.ApiKey != null), masterPasswordHash); + + await sutProvider.GetDependency() + .Received(1) + .SendWelcomeEmailAsync(user); + + await sutProvider.GetDependency() + .Received(1) + .RaiseEventAsync(Arg.Is(refEvent => refEvent.Type == ReferenceEventType.Signup)); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaProviderInviteToken_InvalidToken_ThrowsBadRequestException(SutProvider sutProvider, + User user, string masterPasswordHash, Guid providerUserId) + { + // Arrange + // Start with plaintext + var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); + var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {user.Email} {nowMillis}"; + + // Get the byte array of the plaintext + var decryptedProviderInviteTokenByteArray = Encoding.UTF8.GetBytes(decryptedProviderInviteToken); + + // Base64 encode the byte array (this is passed to protector.protect(bytes)) + var base64EncodedProviderInvToken = WebEncoders.Base64UrlEncode(decryptedProviderInviteTokenByteArray); + + var mockDataProtector = Substitute.For(); + + // Given any byte array, just return the decryptedProviderInviteTokenByteArray (sidestepping any actual encryption) + mockDataProtector.Unprotect(Arg.Any()).Returns(decryptedProviderInviteTokenByteArray); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(mockDataProtector); + + sutProvider.GetDependency() + .OrganizationInviteExpirationHours.Returns(120); // 5 days + + // Using sutProvider in the parameters of the function means that the constructor has already run for the + // command so we have to recreate it in order for our mock overrides to be used. + sutProvider.Create(); + + // Act & Assert + var result = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaProviderInviteToken(user, masterPasswordHash, base64EncodedProviderInvToken, Guid.NewGuid())); + Assert.Equal("Invalid provider invite token.", result.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaProviderInviteToken_DisabledOpenRegistration_ThrowsBadRequestException(SutProvider sutProvider, + User user, string masterPasswordHash, Guid providerUserId) + { + // Arrange + // Start with plaintext + var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); + var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {user.Email} {nowMillis}"; + + // Get the byte array of the plaintext + var decryptedProviderInviteTokenByteArray = Encoding.UTF8.GetBytes(decryptedProviderInviteToken); + + // Base64 encode the byte array (this is passed to protector.protect(bytes)) + var base64EncodedProviderInvToken = WebEncoders.Base64UrlEncode(decryptedProviderInviteTokenByteArray); + + var mockDataProtector = Substitute.For(); + + // Given any byte array, just return the decryptedProviderInviteTokenByteArray (sidestepping any actual encryption) + mockDataProtector.Unprotect(Arg.Any()).Returns(decryptedProviderInviteTokenByteArray); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(mockDataProtector); + + sutProvider.GetDependency() + .DisableUserRegistration = true; + + // Using sutProvider in the parameters of the function means that the constructor has already run for the + // command so we have to recreate it in order for our mock overrides to be used. + sutProvider.Create(); + + // Act & Assert + var result = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaProviderInviteToken(user, masterPasswordHash, base64EncodedProviderInvToken, providerUserId)); + Assert.Equal("Open registration has been disabled by the system administrator.", result.Message); + } + } diff --git a/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs b/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs index 50f7d70abf..3b8534ef32 100644 --- a/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs +++ b/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs @@ -1,4 +1,5 @@ using System.ComponentModel.DataAnnotations; +using System.Text; using Bit.Core; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Models.Api.Request.Accounts; @@ -9,10 +10,12 @@ using Bit.Core.Models.Business.Tokenables; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Tokens; +using Bit.Core.Utilities; using Bit.Identity.Models.Request.Accounts; using Bit.IntegrationTestCommon.Factories; using Bit.Test.Common.AutoFixture.Attributes; - +using Microsoft.AspNetCore.DataProtection; +using Microsoft.AspNetCore.WebUtilities; using Microsoft.EntityFrameworkCore; using NSubstitute; using Xunit; @@ -470,6 +473,80 @@ public class AccountsControllerTests : IClassFixture Assert.Equal(kdfParallelism, user.KdfParallelism); } + [Theory, BitAutoData] + public async Task RegistrationWithEmailVerification_WithProviderInviteToken_Succeeds( + [StringLength(1000)] string masterPasswordHash, [StringLength(50)] string masterPasswordHint, string userSymmetricKey, + KeysRequestModel userAsymmetricKeys, int kdfMemory, int kdfParallelism) + { + + // Localize factory to just this test. + var localFactory = new IdentityApplicationFactory(); + + // Hardcoded, valid data + var email = "jsnider+local253@bitwarden.com"; + var providerUserId = new Guid("c6fdba35-2e52-43b4-8fb7-b211011d154a"); + var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); + var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {email} {nowMillis}"; + // var providerInviteToken = await GetValidProviderInviteToken(localFactory, email, providerUserId); + + // Get the byte array of the plaintext + var decryptedProviderInviteTokenByteArray = Encoding.UTF8.GetBytes(decryptedProviderInviteToken); + + // Base64 encode the byte array (this is passed to protector.protect(bytes)) + var base64EncodedProviderInvToken = WebEncoders.Base64UrlEncode(decryptedProviderInviteTokenByteArray); + + var mockDataProtector = Substitute.For(); + mockDataProtector.Unprotect(Arg.Any()).Returns(decryptedProviderInviteTokenByteArray); + + localFactory.SubstituteService(dataProtectionProvider => + { + dataProtectionProvider.CreateProtector(Arg.Any()) + .Returns(mockDataProtector); + }); + + // As token contains now milliseconds for when it was created, create 1k year timespan for expiration + // to ensure token is valid for a good long while. + localFactory.UpdateConfiguration("globalSettings:OrganizationInviteExpirationHours", "8760000"); + + var registerFinishReqModel = new RegisterFinishRequestModel + { + Email = email, + MasterPasswordHash = masterPasswordHash, + MasterPasswordHint = masterPasswordHint, + ProviderInviteToken = base64EncodedProviderInvToken, + ProviderUserId = providerUserId, + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = AuthConstants.PBKDF2_ITERATIONS.Default, + UserSymmetricKey = userSymmetricKey, + UserAsymmetricKeys = userAsymmetricKeys, + KdfMemory = kdfMemory, + KdfParallelism = kdfParallelism + }; + + var postRegisterFinishHttpContext = await localFactory.PostRegisterFinishAsync(registerFinishReqModel); + + Assert.Equal(StatusCodes.Status200OK, postRegisterFinishHttpContext.Response.StatusCode); + + var database = localFactory.GetDatabaseContext(); + var user = await database.Users + .SingleAsync(u => u.Email == email); + + Assert.NotNull(user); + + // Assert user properties match the request model + Assert.Equal(email, user.Email); + Assert.NotEqual(masterPasswordHash, user.MasterPassword); // We execute server side hashing + Assert.NotNull(user.MasterPassword); + Assert.Equal(masterPasswordHint, user.MasterPasswordHint); + Assert.Equal(userSymmetricKey, user.Key); + Assert.Equal(userAsymmetricKeys.EncryptedPrivateKey, user.PrivateKey); + Assert.Equal(userAsymmetricKeys.PublicKey, user.PublicKey); + Assert.Equal(KdfType.PBKDF2_SHA256, user.Kdf); + Assert.Equal(AuthConstants.PBKDF2_ITERATIONS.Default, user.KdfIterations); + Assert.Equal(kdfMemory, user.KdfMemory); + Assert.Equal(kdfParallelism, user.KdfParallelism); + } + [Theory, BitAutoData] public async Task PostRegisterVerificationEmailClicked_Success( @@ -527,4 +604,5 @@ public class AccountsControllerTests : IClassFixture return user; } + } diff --git a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs index aafe86d56a..3ce2599705 100644 --- a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs +++ b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs @@ -57,6 +57,16 @@ public abstract class WebApplicationFactoryBase : WebApplicationFactory }); } + /// + /// Allows you to add your own services to the application as required. + /// + /// The service collection you want added to the test service collection. + /// This needs to be ran BEFORE making any calls through the factory to take effect. + public void ConfigureServices(Action configure) + { + _configureTestServices.Add(configure); + } + /// /// Add your own configuration provider to the application. /// From 4a1b90db4851c617344b3d451e92a95596493c99 Mon Sep 17 00:00:00 2001 From: Thomas Rittson <31796059+eliykat@users.noreply.github.com> Date: Thu, 24 Oct 2024 08:09:07 +1000 Subject: [PATCH 4/5] Remove bulk-device-approval feature flag definition (#4930) --- src/Core/Constants.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index b22e2cf91f..d4408e7a3a 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -117,7 +117,6 @@ public static class FeatureFlagKeys public const string RestrictProviderAccess = "restrict-provider-access"; public const string PM4154BulkEncryptionService = "PM-4154-bulk-encryption-service"; public const string VaultBulkManagementAction = "vault-bulk-management-action"; - public const string BulkDeviceApproval = "bulk-device-approval"; public const string MemberAccessReport = "ac-2059-member-access-report"; public const string BlockLegacyUsers = "block-legacy-users"; public const string InlineMenuFieldQualification = "inline-menu-field-qualification"; @@ -165,7 +164,6 @@ public static class FeatureFlagKeys return new Dictionary() { { DuoRedirect, "true" }, - { BulkDeviceApproval, "true" }, { CipherKeyEncryption, "true" }, }; } From d38c489443b559141acde1ca4ad1e5338815cdef Mon Sep 17 00:00:00 2001 From: Jonas Hendrickx Date: Thu, 24 Oct 2024 08:34:27 +0200 Subject: [PATCH 5/5] [PM-13982] [Defect] Can no longer create providers due to foreign key conflict (#4935) --- .../Providers/CreateProviderCommand.cs | 43 +++++++++---------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/CreateProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/CreateProviderCommand.cs index d192073d4d..3b01370ef7 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/CreateProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/CreateProviderCommand.cs @@ -41,13 +41,15 @@ public class CreateProviderCommand : ICreateProviderCommand public async Task CreateMspAsync(Provider provider, string ownerEmail, int teamsMinimumSeats, int enterpriseMinimumSeats) { - var providerPlans = new List - { - CreateProviderPlan(provider.Id, PlanType.TeamsMonthly, teamsMinimumSeats), - CreateProviderPlan(provider.Id, PlanType.EnterpriseMonthly, enterpriseMinimumSeats) - }; + var providerId = await CreateProviderAsync(provider, ownerEmail); - await CreateProviderAsync(provider, ownerEmail, providerPlans); + var isConsolidatedBillingEnabled = _featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling); + + if (isConsolidatedBillingEnabled) + { + await CreateProviderPlanAsync(providerId, PlanType.TeamsMonthly, teamsMinimumSeats); + await CreateProviderPlanAsync(providerId, PlanType.EnterpriseMonthly, enterpriseMinimumSeats); + } } public async Task CreateResellerAsync(Provider provider) @@ -57,15 +59,17 @@ public class CreateProviderCommand : ICreateProviderCommand public async Task CreateMultiOrganizationEnterpriseAsync(Provider provider, string ownerEmail, PlanType plan, int minimumSeats) { - var providerPlans = new List - { - CreateProviderPlan(provider.Id, plan, minimumSeats) - }; + var providerId = await CreateProviderAsync(provider, ownerEmail); - await CreateProviderAsync(provider, ownerEmail, providerPlans); + var isConsolidatedBillingEnabled = _featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling); + + if (isConsolidatedBillingEnabled) + { + await CreateProviderPlanAsync(providerId, plan, minimumSeats); + } } - private async Task CreateProviderAsync(Provider provider, string ownerEmail, List providerPlans) + private async Task CreateProviderAsync(Provider provider, string ownerEmail) { var owner = await _userRepository.GetByEmailAsync(ownerEmail); if (owner == null) @@ -90,16 +94,10 @@ public class CreateProviderCommand : ICreateProviderCommand Status = ProviderUserStatusType.Confirmed, }; - if (isConsolidatedBillingEnabled) - { - foreach (var providerPlan in providerPlans) - { - await _providerPlanRepository.CreateAsync(providerPlan); - } - } - await _providerUserRepository.CreateAsync(providerUser); await _providerService.SendProviderSetupInviteEmailAsync(provider, owner.Email); + + return provider.Id; } private async Task ProviderRepositoryCreateAsync(Provider provider, ProviderStatusType status) @@ -110,9 +108,9 @@ public class CreateProviderCommand : ICreateProviderCommand await _providerRepository.CreateAsync(provider); } - private ProviderPlan CreateProviderPlan(Guid providerId, PlanType planType, int seatMinimum) + private async Task CreateProviderPlanAsync(Guid providerId, PlanType planType, int seatMinimum) { - return new ProviderPlan + var plan = new ProviderPlan { ProviderId = providerId, PlanType = planType, @@ -120,5 +118,6 @@ public class CreateProviderCommand : ICreateProviderCommand PurchasedSeats = 0, AllocatedSeats = 0 }; + await _providerPlanRepository.CreateAsync(plan); } }