1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-04 20:50:21 -05:00

[BEEEP] Lazy load the current user in the CurrentContext

This commit is contained in:
Jonas Hendrickx 2025-04-04 11:10:53 +02:00
parent 67d7d685a6
commit 0f98f3959a
No known key found for this signature in database
GPG Key ID: C4B27F601CE4317D
6 changed files with 36 additions and 31 deletions

View File

@ -2,6 +2,7 @@
using Bit.Api.Billing.Models.Responses;
using Bit.Core.Billing.Models.Api.Requests.Accounts;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
@ -12,15 +13,15 @@ namespace Bit.Api.Billing.Controllers;
[Route("accounts/billing")]
[Authorize("Application")]
public class AccountsBillingController(
ICurrentContext currentContext,
IPaymentService paymentService,
IUserService userService,
IPaymentHistoryService paymentHistoryService) : Controller
{
[HttpGet("history")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<BillingHistoryResponseModel> GetBillingHistoryAsync()
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;
if (user == null)
{
throw new UnauthorizedAccessException();
@ -34,7 +35,7 @@ public class AccountsBillingController(
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<BillingPaymentResponseModel> GetPaymentMethodAsync()
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;
if (user == null)
{
throw new UnauthorizedAccessException();
@ -47,7 +48,7 @@ public class AccountsBillingController(
[HttpGet("invoices")]
public async Task<IResult> GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null)
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;
if (user == null)
{
throw new UnauthorizedAccessException();
@ -65,7 +66,7 @@ public class AccountsBillingController(
[HttpGet("transactions")]
public async Task<IResult> GetTransactionsAsync([FromQuery] DateTime? startAfter = null)
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;
if (user == null)
{
throw new UnauthorizedAccessException();
@ -82,7 +83,7 @@ public class AccountsBillingController(
[HttpPost("preview-invoice")]
public async Task<IResult> PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model)
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;
if (user == null)
{
throw new UnauthorizedAccessException();

View File

@ -44,31 +44,30 @@ public class UserStore :
public async Task<User> FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken))
{
if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail)
var currentUser = await _currentContext.UserAsync.Value;
if (currentUser != null && currentUser.Email == normalizedEmail)
{
return _currentContext.User;
return currentUser;
}
_currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail);
return _currentContext.User;
return await _userRepository.GetByEmailAsync(normalizedEmail);
}
public async Task<User> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken))
{
if (_currentContext?.User != null &&
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
var currentUser = await _currentContext.UserAsync.Value;
if (currentUser != null &&
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
{
return _currentContext.User;
return currentUser;
}
Guid userIdGuid;
if (!Guid.TryParse(userId, out userIdGuid))
if (!Guid.TryParse(userId, out var userIdGuid))
{
return null;
}
_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid);
return _currentContext.User;
return await _userRepository.GetByIdAsync(userIdGuid);
}
public async Task<User> FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken))

View File

@ -19,6 +19,8 @@ public class CurrentContext : ICurrentContext
{
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IProviderUserRepository _providerUserRepository;
private readonly IUserRepository _userRepository;
private bool _builtHttpContext;
private bool _builtClaimsPrincipal;
private IEnumerable<ProviderOrganizationProviderDetails> _providerOrganizationProviderDetails;
@ -26,7 +28,7 @@ public class CurrentContext : ICurrentContext
public virtual HttpContext HttpContext { get; set; }
public virtual Guid? UserId { get; set; }
public virtual User User { get; set; }
public virtual Lazy<Task<User>> UserAsync { get; private set; }
public virtual string DeviceIdentifier { get; set; }
public virtual DeviceType? DeviceType { get; set; }
public virtual string IpAddress { get; set; }
@ -47,10 +49,12 @@ public class CurrentContext : ICurrentContext
public CurrentContext(
IProviderOrganizationRepository providerOrganizationRepository,
IProviderUserRepository providerUserRepository)
IProviderUserRepository providerUserRepository,
IUserRepository userRepository)
{
_providerOrganizationRepository = providerOrganizationRepository;
_providerUserRepository = providerUserRepository;
_userRepository = userRepository;
}
public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings)
@ -138,6 +142,7 @@ public class CurrentContext : ICurrentContext
if (Guid.TryParse(subject, out var subIdGuid))
{
UserId = subIdGuid;
UserAsync = new Lazy<Task<User>>(() => _userRepository.GetByIdAsync(UserId.Value));
}
ClientId = GetClaimValue(claimsDict, "client_id");

View File

@ -16,7 +16,7 @@ public interface ICurrentContext
{
HttpContext HttpContext { get; set; }
Guid? UserId { get; set; }
User User { get; set; }
Lazy<Task<User>> UserAsync { get; }
string DeviceIdentifier { get; set; }
DeviceType? DeviceType { get; set; }
string IpAddress { get; set; }

View File

@ -173,10 +173,11 @@ public class UserService : UserManager<User>, IUserService, IDisposable
public async Task<User> GetUserByIdAsync(string userId)
{
if (_currentContext?.User != null &&
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
var currentUser = await _currentContext.UserAsync.Value;
if (currentUser != null &&
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
{
return _currentContext.User;
return currentUser;
}
if (!Guid.TryParse(userId, out var userIdGuid))
@ -184,19 +185,18 @@ public class UserService : UserManager<User>, IUserService, IDisposable
return null;
}
_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid);
return _currentContext.User;
return await _userRepository.GetByIdAsync(userIdGuid);
}
public async Task<User> GetUserByIdAsync(Guid userId)
{
if (_currentContext?.User != null && _currentContext.User.Id == userId)
var currentUser = await _currentContext.UserAsync.Value;
if (currentUser != null && currentUser.Id == userId)
{
return _currentContext.User;
return currentUser;
}
_currentContext.User = await _userRepository.GetByIdAsync(userId);
return _currentContext.User;
return await _userRepository.GetByIdAsync(userId);
}
public async Task<User> GetUserByPrincipalAsync(ClaimsPrincipal principal)

View File

@ -20,7 +20,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
public override async Task OnConnectedAsync()
{
var currentContext = new CurrentContext(null, null);
var currentContext = new CurrentContext(null, null, null);
await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
@ -57,7 +57,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
public override async Task OnDisconnectedAsync(Exception exception)
{
var currentContext = new CurrentContext(null, null);
var currentContext = new CurrentContext(null, null, null);
await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);