mirror of
https://github.com/bitwarden/server.git
synced 2025-04-04 12:40:22 -05:00
[BEEEP] Lazy load the current user in the CurrentContext
This commit is contained in:
parent
67d7d685a6
commit
0f98f3959a
@ -2,6 +2,7 @@
|
||||
using Bit.Api.Billing.Models.Responses;
|
||||
using Bit.Core.Billing.Models.Api.Requests.Accounts;
|
||||
using Bit.Core.Billing.Services;
|
||||
using Bit.Core.Context;
|
||||
using Bit.Core.Services;
|
||||
using Bit.Core.Utilities;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
@ -12,15 +13,15 @@ namespace Bit.Api.Billing.Controllers;
|
||||
[Route("accounts/billing")]
|
||||
[Authorize("Application")]
|
||||
public class AccountsBillingController(
|
||||
ICurrentContext currentContext,
|
||||
IPaymentService paymentService,
|
||||
IUserService userService,
|
||||
IPaymentHistoryService paymentHistoryService) : Controller
|
||||
{
|
||||
[HttpGet("history")]
|
||||
[SelfHosted(NotSelfHostedOnly = true)]
|
||||
public async Task<BillingHistoryResponseModel> GetBillingHistoryAsync()
|
||||
{
|
||||
var user = await userService.GetUserByPrincipalAsync(User);
|
||||
var user = await currentContext.UserAsync.Value;
|
||||
if (user == null)
|
||||
{
|
||||
throw new UnauthorizedAccessException();
|
||||
@ -34,7 +35,7 @@ public class AccountsBillingController(
|
||||
[SelfHosted(NotSelfHostedOnly = true)]
|
||||
public async Task<BillingPaymentResponseModel> GetPaymentMethodAsync()
|
||||
{
|
||||
var user = await userService.GetUserByPrincipalAsync(User);
|
||||
var user = await currentContext.UserAsync.Value;
|
||||
if (user == null)
|
||||
{
|
||||
throw new UnauthorizedAccessException();
|
||||
@ -47,7 +48,7 @@ public class AccountsBillingController(
|
||||
[HttpGet("invoices")]
|
||||
public async Task<IResult> GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null)
|
||||
{
|
||||
var user = await userService.GetUserByPrincipalAsync(User);
|
||||
var user = await currentContext.UserAsync.Value;
|
||||
if (user == null)
|
||||
{
|
||||
throw new UnauthorizedAccessException();
|
||||
@ -65,7 +66,7 @@ public class AccountsBillingController(
|
||||
[HttpGet("transactions")]
|
||||
public async Task<IResult> GetTransactionsAsync([FromQuery] DateTime? startAfter = null)
|
||||
{
|
||||
var user = await userService.GetUserByPrincipalAsync(User);
|
||||
var user = await currentContext.UserAsync.Value;
|
||||
if (user == null)
|
||||
{
|
||||
throw new UnauthorizedAccessException();
|
||||
@ -82,7 +83,7 @@ public class AccountsBillingController(
|
||||
[HttpPost("preview-invoice")]
|
||||
public async Task<IResult> PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model)
|
||||
{
|
||||
var user = await userService.GetUserByPrincipalAsync(User);
|
||||
var user = await currentContext.UserAsync.Value;
|
||||
if (user == null)
|
||||
{
|
||||
throw new UnauthorizedAccessException();
|
||||
|
@ -44,31 +44,30 @@ public class UserStore :
|
||||
|
||||
public async Task<User> FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken))
|
||||
{
|
||||
if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail)
|
||||
var currentUser = await _currentContext.UserAsync.Value;
|
||||
if (currentUser != null && currentUser.Email == normalizedEmail)
|
||||
{
|
||||
return _currentContext.User;
|
||||
return currentUser;
|
||||
}
|
||||
|
||||
_currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail);
|
||||
return _currentContext.User;
|
||||
return await _userRepository.GetByEmailAsync(normalizedEmail);
|
||||
}
|
||||
|
||||
public async Task<User> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken))
|
||||
{
|
||||
if (_currentContext?.User != null &&
|
||||
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
|
||||
var currentUser = await _currentContext.UserAsync.Value;
|
||||
if (currentUser != null &&
|
||||
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
|
||||
{
|
||||
return _currentContext.User;
|
||||
return currentUser;
|
||||
}
|
||||
|
||||
Guid userIdGuid;
|
||||
if (!Guid.TryParse(userId, out userIdGuid))
|
||||
if (!Guid.TryParse(userId, out var userIdGuid))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid);
|
||||
return _currentContext.User;
|
||||
return await _userRepository.GetByIdAsync(userIdGuid);
|
||||
}
|
||||
|
||||
public async Task<User> FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken))
|
||||
|
@ -19,6 +19,8 @@ public class CurrentContext : ICurrentContext
|
||||
{
|
||||
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
|
||||
private readonly IProviderUserRepository _providerUserRepository;
|
||||
private readonly IUserRepository _userRepository;
|
||||
|
||||
private bool _builtHttpContext;
|
||||
private bool _builtClaimsPrincipal;
|
||||
private IEnumerable<ProviderOrganizationProviderDetails> _providerOrganizationProviderDetails;
|
||||
@ -26,7 +28,7 @@ public class CurrentContext : ICurrentContext
|
||||
|
||||
public virtual HttpContext HttpContext { get; set; }
|
||||
public virtual Guid? UserId { get; set; }
|
||||
public virtual User User { get; set; }
|
||||
public virtual Lazy<Task<User>> UserAsync { get; private set; }
|
||||
public virtual string DeviceIdentifier { get; set; }
|
||||
public virtual DeviceType? DeviceType { get; set; }
|
||||
public virtual string IpAddress { get; set; }
|
||||
@ -47,10 +49,12 @@ public class CurrentContext : ICurrentContext
|
||||
|
||||
public CurrentContext(
|
||||
IProviderOrganizationRepository providerOrganizationRepository,
|
||||
IProviderUserRepository providerUserRepository)
|
||||
IProviderUserRepository providerUserRepository,
|
||||
IUserRepository userRepository)
|
||||
{
|
||||
_providerOrganizationRepository = providerOrganizationRepository;
|
||||
_providerUserRepository = providerUserRepository;
|
||||
_userRepository = userRepository;
|
||||
}
|
||||
|
||||
public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings)
|
||||
@ -138,6 +142,7 @@ public class CurrentContext : ICurrentContext
|
||||
if (Guid.TryParse(subject, out var subIdGuid))
|
||||
{
|
||||
UserId = subIdGuid;
|
||||
UserAsync = new Lazy<Task<User>>(() => _userRepository.GetByIdAsync(UserId.Value));
|
||||
}
|
||||
|
||||
ClientId = GetClaimValue(claimsDict, "client_id");
|
||||
|
@ -16,7 +16,7 @@ public interface ICurrentContext
|
||||
{
|
||||
HttpContext HttpContext { get; set; }
|
||||
Guid? UserId { get; set; }
|
||||
User User { get; set; }
|
||||
Lazy<Task<User>> UserAsync { get; }
|
||||
string DeviceIdentifier { get; set; }
|
||||
DeviceType? DeviceType { get; set; }
|
||||
string IpAddress { get; set; }
|
||||
|
@ -173,10 +173,11 @@ public class UserService : UserManager<User>, IUserService, IDisposable
|
||||
|
||||
public async Task<User> GetUserByIdAsync(string userId)
|
||||
{
|
||||
if (_currentContext?.User != null &&
|
||||
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
|
||||
var currentUser = await _currentContext.UserAsync.Value;
|
||||
if (currentUser != null &&
|
||||
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
|
||||
{
|
||||
return _currentContext.User;
|
||||
return currentUser;
|
||||
}
|
||||
|
||||
if (!Guid.TryParse(userId, out var userIdGuid))
|
||||
@ -184,19 +185,18 @@ public class UserService : UserManager<User>, IUserService, IDisposable
|
||||
return null;
|
||||
}
|
||||
|
||||
_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid);
|
||||
return _currentContext.User;
|
||||
return await _userRepository.GetByIdAsync(userIdGuid);
|
||||
}
|
||||
|
||||
public async Task<User> GetUserByIdAsync(Guid userId)
|
||||
{
|
||||
if (_currentContext?.User != null && _currentContext.User.Id == userId)
|
||||
var currentUser = await _currentContext.UserAsync.Value;
|
||||
if (currentUser != null && currentUser.Id == userId)
|
||||
{
|
||||
return _currentContext.User;
|
||||
return currentUser;
|
||||
}
|
||||
|
||||
_currentContext.User = await _userRepository.GetByIdAsync(userId);
|
||||
return _currentContext.User;
|
||||
return await _userRepository.GetByIdAsync(userId);
|
||||
}
|
||||
|
||||
public async Task<User> GetUserByPrincipalAsync(ClaimsPrincipal principal)
|
||||
|
@ -20,7 +20,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
|
||||
|
||||
public override async Task OnConnectedAsync()
|
||||
{
|
||||
var currentContext = new CurrentContext(null, null);
|
||||
var currentContext = new CurrentContext(null, null, null);
|
||||
await currentContext.BuildAsync(Context.User, _globalSettings);
|
||||
|
||||
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
|
||||
@ -57,7 +57,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
|
||||
|
||||
public override async Task OnDisconnectedAsync(Exception exception)
|
||||
{
|
||||
var currentContext = new CurrentContext(null, null);
|
||||
var currentContext = new CurrentContext(null, null, null);
|
||||
await currentContext.BuildAsync(Context.User, _globalSettings);
|
||||
|
||||
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
|
||||
|
Loading…
x
Reference in New Issue
Block a user