diff --git a/src/Api/Billing/Controllers/AccountsBillingController.cs b/src/Api/Billing/Controllers/AccountsBillingController.cs index fcb89226e7..17418377d0 100644 --- a/src/Api/Billing/Controllers/AccountsBillingController.cs +++ b/src/Api/Billing/Controllers/AccountsBillingController.cs @@ -2,6 +2,7 @@ using Bit.Api.Billing.Models.Responses; using Bit.Core.Billing.Models.Api.Requests.Accounts; using Bit.Core.Billing.Services; +using Bit.Core.Context; using Bit.Core.Services; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; @@ -12,15 +13,15 @@ namespace Bit.Api.Billing.Controllers; [Route("accounts/billing")] [Authorize("Application")] public class AccountsBillingController( + ICurrentContext currentContext, IPaymentService paymentService, - IUserService userService, IPaymentHistoryService paymentHistoryService) : Controller { [HttpGet("history")] [SelfHosted(NotSelfHostedOnly = true)] public async Task GetBillingHistoryAsync() { - var user = await userService.GetUserByPrincipalAsync(User); + var user = await currentContext.UserAsync.Value; if (user == null) { throw new UnauthorizedAccessException(); @@ -34,7 +35,7 @@ public class AccountsBillingController( [SelfHosted(NotSelfHostedOnly = true)] public async Task GetPaymentMethodAsync() { - var user = await userService.GetUserByPrincipalAsync(User); + var user = await currentContext.UserAsync.Value; if (user == null) { throw new UnauthorizedAccessException(); @@ -47,7 +48,7 @@ public class AccountsBillingController( [HttpGet("invoices")] public async Task GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null) { - var user = await userService.GetUserByPrincipalAsync(User); + var user = await currentContext.UserAsync.Value; if (user == null) { throw new UnauthorizedAccessException(); @@ -65,7 +66,7 @@ public class AccountsBillingController( [HttpGet("transactions")] public async Task GetTransactionsAsync([FromQuery] DateTime? startAfter = null) { - var user = await userService.GetUserByPrincipalAsync(User); + var user = await currentContext.UserAsync.Value; if (user == null) { throw new UnauthorizedAccessException(); @@ -82,7 +83,7 @@ public class AccountsBillingController( [HttpPost("preview-invoice")] public async Task PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model) { - var user = await userService.GetUserByPrincipalAsync(User); + var user = await currentContext.UserAsync.Value; if (user == null) { throw new UnauthorizedAccessException(); diff --git a/src/Core/Auth/Identity/UserStore.cs b/src/Core/Auth/Identity/UserStore.cs index 3716d75b6a..d487b34157 100644 --- a/src/Core/Auth/Identity/UserStore.cs +++ b/src/Core/Auth/Identity/UserStore.cs @@ -44,31 +44,30 @@ public class UserStore : public async Task FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) { - if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail) + var currentUser = await _currentContext.UserAsync.Value; + if (currentUser != null && currentUser.Email == normalizedEmail) { - return _currentContext.User; + return currentUser; } - _currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail); - return _currentContext.User; + return await _userRepository.GetByEmailAsync(normalizedEmail); } public async Task FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken)) { - if (_currentContext?.User != null && - string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) + var currentUser = await _currentContext.UserAsync.Value; + if (currentUser != null && + string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) { - return _currentContext.User; + return currentUser; } - Guid userIdGuid; - if (!Guid.TryParse(userId, out userIdGuid)) + if (!Guid.TryParse(userId, out var userIdGuid)) { return null; } - _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); - return _currentContext.User; + return await _userRepository.GetByIdAsync(userIdGuid); } public async Task FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken)) diff --git a/src/Core/Context/CurrentContext.cs b/src/Core/Context/CurrentContext.cs index cbd90055b0..40a1bca33f 100644 --- a/src/Core/Context/CurrentContext.cs +++ b/src/Core/Context/CurrentContext.cs @@ -19,6 +19,8 @@ public class CurrentContext : ICurrentContext { private readonly IProviderOrganizationRepository _providerOrganizationRepository; private readonly IProviderUserRepository _providerUserRepository; + private readonly IUserRepository _userRepository; + private bool _builtHttpContext; private bool _builtClaimsPrincipal; private IEnumerable _providerOrganizationProviderDetails; @@ -26,7 +28,7 @@ public class CurrentContext : ICurrentContext public virtual HttpContext HttpContext { get; set; } public virtual Guid? UserId { get; set; } - public virtual User User { get; set; } + public virtual Lazy> UserAsync { get; private set; } public virtual string DeviceIdentifier { get; set; } public virtual DeviceType? DeviceType { get; set; } public virtual string IpAddress { get; set; } @@ -47,10 +49,12 @@ public class CurrentContext : ICurrentContext public CurrentContext( IProviderOrganizationRepository providerOrganizationRepository, - IProviderUserRepository providerUserRepository) + IProviderUserRepository providerUserRepository, + IUserRepository userRepository) { _providerOrganizationRepository = providerOrganizationRepository; _providerUserRepository = providerUserRepository; + _userRepository = userRepository; } public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings) @@ -138,6 +142,7 @@ public class CurrentContext : ICurrentContext if (Guid.TryParse(subject, out var subIdGuid)) { UserId = subIdGuid; + UserAsync = new Lazy>(() => _userRepository.GetByIdAsync(UserId.Value)); } ClientId = GetClaimValue(claimsDict, "client_id"); diff --git a/src/Core/Context/ICurrentContext.cs b/src/Core/Context/ICurrentContext.cs index 42843ce6d7..6cef6ab167 100644 --- a/src/Core/Context/ICurrentContext.cs +++ b/src/Core/Context/ICurrentContext.cs @@ -16,7 +16,7 @@ public interface ICurrentContext { HttpContext HttpContext { get; set; } Guid? UserId { get; set; } - User User { get; set; } + Lazy> UserAsync { get; } string DeviceIdentifier { get; set; } DeviceType? DeviceType { get; set; } string IpAddress { get; set; } diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index 5076c8282e..f10c6d4f90 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -173,10 +173,11 @@ public class UserService : UserManager, IUserService, IDisposable public async Task GetUserByIdAsync(string userId) { - if (_currentContext?.User != null && - string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) + var currentUser = await _currentContext.UserAsync.Value; + if (currentUser != null && + string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) { - return _currentContext.User; + return currentUser; } if (!Guid.TryParse(userId, out var userIdGuid)) @@ -184,19 +185,18 @@ public class UserService : UserManager, IUserService, IDisposable return null; } - _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); - return _currentContext.User; + return await _userRepository.GetByIdAsync(userIdGuid); } public async Task GetUserByIdAsync(Guid userId) { - if (_currentContext?.User != null && _currentContext.User.Id == userId) + var currentUser = await _currentContext.UserAsync.Value; + if (currentUser != null && currentUser.Id == userId) { - return _currentContext.User; + return currentUser; } - _currentContext.User = await _userRepository.GetByIdAsync(userId); - return _currentContext.User; + return await _userRepository.GetByIdAsync(userId); } public async Task GetUserByPrincipalAsync(ClaimsPrincipal principal) diff --git a/src/Notifications/NotificationsHub.cs b/src/Notifications/NotificationsHub.cs index ed62dbbd66..88fb6d66f8 100644 --- a/src/Notifications/NotificationsHub.cs +++ b/src/Notifications/NotificationsHub.cs @@ -20,7 +20,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub public override async Task OnConnectedAsync() { - var currentContext = new CurrentContext(null, null); + var currentContext = new CurrentContext(null, null, null); await currentContext.BuildAsync(Context.User, _globalSettings); var clientType = DeviceTypes.ToClientType(currentContext.DeviceType); @@ -57,7 +57,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub public override async Task OnDisconnectedAsync(Exception exception) { - var currentContext = new CurrentContext(null, null); + var currentContext = new CurrentContext(null, null, null); await currentContext.BuildAsync(Context.User, _globalSettings); var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);