1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-04 20:50:21 -05:00

[BEEEP] Lazy load the current user in the CurrentContext

This commit is contained in:
Jonas Hendrickx 2025-04-04 11:10:53 +02:00
parent 67d7d685a6
commit 0f98f3959a
No known key found for this signature in database
GPG Key ID: C4B27F601CE4317D
6 changed files with 36 additions and 31 deletions

View File

@ -2,6 +2,7 @@
using Bit.Api.Billing.Models.Responses; using Bit.Api.Billing.Models.Responses;
using Bit.Core.Billing.Models.Api.Requests.Accounts; using Bit.Core.Billing.Models.Api.Requests.Accounts;
using Bit.Core.Billing.Services; using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Services; using Bit.Core.Services;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
@ -12,15 +13,15 @@ namespace Bit.Api.Billing.Controllers;
[Route("accounts/billing")] [Route("accounts/billing")]
[Authorize("Application")] [Authorize("Application")]
public class AccountsBillingController( public class AccountsBillingController(
ICurrentContext currentContext,
IPaymentService paymentService, IPaymentService paymentService,
IUserService userService,
IPaymentHistoryService paymentHistoryService) : Controller IPaymentHistoryService paymentHistoryService) : Controller
{ {
[HttpGet("history")] [HttpGet("history")]
[SelfHosted(NotSelfHostedOnly = true)] [SelfHosted(NotSelfHostedOnly = true)]
public async Task<BillingHistoryResponseModel> GetBillingHistoryAsync() public async Task<BillingHistoryResponseModel> GetBillingHistoryAsync()
{ {
var user = await userService.GetUserByPrincipalAsync(User); var user = await currentContext.UserAsync.Value;
if (user == null) if (user == null)
{ {
throw new UnauthorizedAccessException(); throw new UnauthorizedAccessException();
@ -34,7 +35,7 @@ public class AccountsBillingController(
[SelfHosted(NotSelfHostedOnly = true)] [SelfHosted(NotSelfHostedOnly = true)]
public async Task<BillingPaymentResponseModel> GetPaymentMethodAsync() public async Task<BillingPaymentResponseModel> GetPaymentMethodAsync()
{ {
var user = await userService.GetUserByPrincipalAsync(User); var user = await currentContext.UserAsync.Value;
if (user == null) if (user == null)
{ {
throw new UnauthorizedAccessException(); throw new UnauthorizedAccessException();
@ -47,7 +48,7 @@ public class AccountsBillingController(
[HttpGet("invoices")] [HttpGet("invoices")]
public async Task<IResult> GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null) public async Task<IResult> GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null)
{ {
var user = await userService.GetUserByPrincipalAsync(User); var user = await currentContext.UserAsync.Value;
if (user == null) if (user == null)
{ {
throw new UnauthorizedAccessException(); throw new UnauthorizedAccessException();
@ -65,7 +66,7 @@ public class AccountsBillingController(
[HttpGet("transactions")] [HttpGet("transactions")]
public async Task<IResult> GetTransactionsAsync([FromQuery] DateTime? startAfter = null) public async Task<IResult> GetTransactionsAsync([FromQuery] DateTime? startAfter = null)
{ {
var user = await userService.GetUserByPrincipalAsync(User); var user = await currentContext.UserAsync.Value;
if (user == null) if (user == null)
{ {
throw new UnauthorizedAccessException(); throw new UnauthorizedAccessException();
@ -82,7 +83,7 @@ public class AccountsBillingController(
[HttpPost("preview-invoice")] [HttpPost("preview-invoice")]
public async Task<IResult> PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model) public async Task<IResult> PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model)
{ {
var user = await userService.GetUserByPrincipalAsync(User); var user = await currentContext.UserAsync.Value;
if (user == null) if (user == null)
{ {
throw new UnauthorizedAccessException(); throw new UnauthorizedAccessException();

View File

@ -44,31 +44,30 @@ public class UserStore :
public async Task<User> FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) public async Task<User> FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken))
{ {
if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail) var currentUser = await _currentContext.UserAsync.Value;
if (currentUser != null && currentUser.Email == normalizedEmail)
{ {
return _currentContext.User; return currentUser;
} }
_currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail); return await _userRepository.GetByEmailAsync(normalizedEmail);
return _currentContext.User;
} }
public async Task<User> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken)) public async Task<User> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken))
{ {
if (_currentContext?.User != null && var currentUser = await _currentContext.UserAsync.Value;
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) if (currentUser != null &&
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
{ {
return _currentContext.User; return currentUser;
} }
Guid userIdGuid; if (!Guid.TryParse(userId, out var userIdGuid))
if (!Guid.TryParse(userId, out userIdGuid))
{ {
return null; return null;
} }
_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); return await _userRepository.GetByIdAsync(userIdGuid);
return _currentContext.User;
} }
public async Task<User> FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken)) public async Task<User> FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken))

View File

@ -19,6 +19,8 @@ public class CurrentContext : ICurrentContext
{ {
private readonly IProviderOrganizationRepository _providerOrganizationRepository; private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IProviderUserRepository _providerUserRepository; private readonly IProviderUserRepository _providerUserRepository;
private readonly IUserRepository _userRepository;
private bool _builtHttpContext; private bool _builtHttpContext;
private bool _builtClaimsPrincipal; private bool _builtClaimsPrincipal;
private IEnumerable<ProviderOrganizationProviderDetails> _providerOrganizationProviderDetails; private IEnumerable<ProviderOrganizationProviderDetails> _providerOrganizationProviderDetails;
@ -26,7 +28,7 @@ public class CurrentContext : ICurrentContext
public virtual HttpContext HttpContext { get; set; } public virtual HttpContext HttpContext { get; set; }
public virtual Guid? UserId { get; set; } public virtual Guid? UserId { get; set; }
public virtual User User { get; set; } public virtual Lazy<Task<User>> UserAsync { get; private set; }
public virtual string DeviceIdentifier { get; set; } public virtual string DeviceIdentifier { get; set; }
public virtual DeviceType? DeviceType { get; set; } public virtual DeviceType? DeviceType { get; set; }
public virtual string IpAddress { get; set; } public virtual string IpAddress { get; set; }
@ -47,10 +49,12 @@ public class CurrentContext : ICurrentContext
public CurrentContext( public CurrentContext(
IProviderOrganizationRepository providerOrganizationRepository, IProviderOrganizationRepository providerOrganizationRepository,
IProviderUserRepository providerUserRepository) IProviderUserRepository providerUserRepository,
IUserRepository userRepository)
{ {
_providerOrganizationRepository = providerOrganizationRepository; _providerOrganizationRepository = providerOrganizationRepository;
_providerUserRepository = providerUserRepository; _providerUserRepository = providerUserRepository;
_userRepository = userRepository;
} }
public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings) public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings)
@ -138,6 +142,7 @@ public class CurrentContext : ICurrentContext
if (Guid.TryParse(subject, out var subIdGuid)) if (Guid.TryParse(subject, out var subIdGuid))
{ {
UserId = subIdGuid; UserId = subIdGuid;
UserAsync = new Lazy<Task<User>>(() => _userRepository.GetByIdAsync(UserId.Value));
} }
ClientId = GetClaimValue(claimsDict, "client_id"); ClientId = GetClaimValue(claimsDict, "client_id");

View File

@ -16,7 +16,7 @@ public interface ICurrentContext
{ {
HttpContext HttpContext { get; set; } HttpContext HttpContext { get; set; }
Guid? UserId { get; set; } Guid? UserId { get; set; }
User User { get; set; } Lazy<Task<User>> UserAsync { get; }
string DeviceIdentifier { get; set; } string DeviceIdentifier { get; set; }
DeviceType? DeviceType { get; set; } DeviceType? DeviceType { get; set; }
string IpAddress { get; set; } string IpAddress { get; set; }

View File

@ -173,10 +173,11 @@ public class UserService : UserManager<User>, IUserService, IDisposable
public async Task<User> GetUserByIdAsync(string userId) public async Task<User> GetUserByIdAsync(string userId)
{ {
if (_currentContext?.User != null && var currentUser = await _currentContext.UserAsync.Value;
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) if (currentUser != null &&
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
{ {
return _currentContext.User; return currentUser;
} }
if (!Guid.TryParse(userId, out var userIdGuid)) if (!Guid.TryParse(userId, out var userIdGuid))
@ -184,19 +185,18 @@ public class UserService : UserManager<User>, IUserService, IDisposable
return null; return null;
} }
_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); return await _userRepository.GetByIdAsync(userIdGuid);
return _currentContext.User;
} }
public async Task<User> GetUserByIdAsync(Guid userId) public async Task<User> GetUserByIdAsync(Guid userId)
{ {
if (_currentContext?.User != null && _currentContext.User.Id == userId) var currentUser = await _currentContext.UserAsync.Value;
if (currentUser != null && currentUser.Id == userId)
{ {
return _currentContext.User; return currentUser;
} }
_currentContext.User = await _userRepository.GetByIdAsync(userId); return await _userRepository.GetByIdAsync(userId);
return _currentContext.User;
} }
public async Task<User> GetUserByPrincipalAsync(ClaimsPrincipal principal) public async Task<User> GetUserByPrincipalAsync(ClaimsPrincipal principal)

View File

@ -20,7 +20,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
public override async Task OnConnectedAsync() public override async Task OnConnectedAsync()
{ {
var currentContext = new CurrentContext(null, null); var currentContext = new CurrentContext(null, null, null);
await currentContext.BuildAsync(Context.User, _globalSettings); await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType); var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
@ -57,7 +57,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
public override async Task OnDisconnectedAsync(Exception exception) public override async Task OnDisconnectedAsync(Exception exception)
{ {
var currentContext = new CurrentContext(null, null); var currentContext = new CurrentContext(null, null, null);
await currentContext.BuildAsync(Context.User, _globalSettings); await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType); var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);