From 26d6852388c254733687903d53c2ff477598f602 Mon Sep 17 00:00:00 2001 From: Bernd Schoolmann Date: Tue, 18 Mar 2025 16:28:02 +0100 Subject: [PATCH] Attempt to fix tests --- .../Controllers/AccountsController.cs | 4 ++-- .../OpaqueKeyExchangeCredentialRepository.cs | 23 +++++++++++++++++++ ...ityFrameworkServiceCollectionExtensions.cs | 2 ++ .../Utilities/ServiceCollectionExtensions.cs | 2 +- .../Controllers/AccountsControllerTests.cs | 15 ++++++++++++ 5 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 src/Infrastructure.EntityFramework/Auth/Repositories/OpaqueKeyExchangeCredentialRepository.cs diff --git a/src/Identity/Controllers/AccountsController.cs b/src/Identity/Controllers/AccountsController.cs index 3ffeeec2ca..0ca5d50c5d 100644 --- a/src/Identity/Controllers/AccountsController.cs +++ b/src/Identity/Controllers/AccountsController.cs @@ -261,12 +261,12 @@ public class AccountsController : Controller public async Task PostPrelogin([FromBody] PreloginRequestModel model) { var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); - if (kdfInformation == null) + var user = await _userRepository.GetByEmailAsync(model.Email); + if (kdfInformation == null || user == null) { kdfInformation = GetDefaultKdf(model.Email); } - var user = await _userRepository.GetByEmailAsync(model.Email); var credential = await _opaqueKeyExchangeCredentialRepository.GetByUserIdAsync(user.Id); if (credential != null) { diff --git a/src/Infrastructure.EntityFramework/Auth/Repositories/OpaqueKeyExchangeCredentialRepository.cs b/src/Infrastructure.EntityFramework/Auth/Repositories/OpaqueKeyExchangeCredentialRepository.cs new file mode 100644 index 0000000000..43e38f79ca --- /dev/null +++ b/src/Infrastructure.EntityFramework/Auth/Repositories/OpaqueKeyExchangeCredentialRepository.cs @@ -0,0 +1,23 @@ +using AutoMapper; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Auth.Repositories; +using Bit.Core.KeyManagement.UserKey; +using Bit.Infrastructure.EntityFramework.Repositories; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +#nullable enable + + +namespace Bit.Infrastructure.Dapper.Auth.Repositories; + +public class OpaqueKeyExchangeCredentialRepository : Repository, IOpaqueKeyExchangeCredentialRepository +{ + public OpaqueKeyExchangeCredentialRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper, Func> getDbSet) : base(serviceScopeFactory, mapper, getDbSet) + { + } + + public Task GetByUserIdAsync(Guid userId) => throw new NotImplementedException(); + public UpdateEncryptedDataForKeyRotation UpdateKeysForRotationAsync(Guid userId, IEnumerable credentials) => throw new NotImplementedException(); +} diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index 3f805bbe2c..3193e88bf1 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -9,6 +9,7 @@ using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; using Bit.Core.Tools.Repositories; using Bit.Core.Vault.Repositories; +using Bit.Infrastructure.Dapper.Auth.Repositories; using Bit.Infrastructure.EntityFramework.AdminConsole.Repositories; using Bit.Infrastructure.EntityFramework.Auth.Repositories; using Bit.Infrastructure.EntityFramework.Billing.Repositories; @@ -87,6 +88,7 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index a0bee13f2e..3255d1044b 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -110,6 +110,7 @@ public static class ServiceCollectionExtensions public static void AddBaseServices(this IServiceCollection services, IGlobalSettings globalSettings) { services.AddScoped(); + services.AddSingleton(); services.AddUserServices(globalSettings); services.AddTrialInitiationServices(); services.AddOrganizationServices(globalSettings); @@ -118,7 +119,6 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); - services.AddScoped(); services.AddSingleton(); services.AddScoped(); services.AddScoped(); diff --git a/test/Identity.Test/Controllers/AccountsControllerTests.cs b/test/Identity.Test/Controllers/AccountsControllerTests.cs index c0fa51fb4f..4928d08316 100644 --- a/test/Identity.Test/Controllers/AccountsControllerTests.cs +++ b/test/Identity.Test/Controllers/AccountsControllerTests.cs @@ -97,6 +97,11 @@ public class AccountsControllerTests : IDisposable KdfIterations = AuthConstants.PBKDF2_ITERATIONS.Default }; _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(userKdfInfo); + var mockUser = new User + { + Email = "user@example.com" + }; + _userRepository.GetByEmailAsync(Arg.Any()).Returns(mockUser); var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); @@ -109,6 +114,11 @@ public class AccountsControllerTests : IDisposable { SetDefaultKdfHmacKey(null); _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(null)); + var mockUser = new User + { + Email = "user@example.com" + }; + _userRepository.GetByEmailAsync(Arg.Any()).Returns(mockUser); var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); @@ -125,6 +135,11 @@ public class AccountsControllerTests : IDisposable SetDefaultKdfHmacKey(defaultKey); _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(null)); + var mockUser = new User + { + Email = "user@example.com" + }; + _userRepository.GetByEmailAsync(Arg.Any()).Returns(mockUser); var fieldInfo = typeof(AccountsController).GetField("_defaultKdfResults", BindingFlags.NonPublic | BindingFlags.Static); if (fieldInfo == null)