mirror of
https://github.com/bitwarden/server.git
synced 2025-04-04 20:50:21 -05:00
[BEEEP] Lazy load the current user in the CurrentContext
This commit is contained in:
parent
67d7d685a6
commit
0f98f3959a
@ -2,6 +2,7 @@
|
|||||||
using Bit.Api.Billing.Models.Responses;
|
using Bit.Api.Billing.Models.Responses;
|
||||||
using Bit.Core.Billing.Models.Api.Requests.Accounts;
|
using Bit.Core.Billing.Models.Api.Requests.Accounts;
|
||||||
using Bit.Core.Billing.Services;
|
using Bit.Core.Billing.Services;
|
||||||
|
using Bit.Core.Context;
|
||||||
using Bit.Core.Services;
|
using Bit.Core.Services;
|
||||||
using Bit.Core.Utilities;
|
using Bit.Core.Utilities;
|
||||||
using Microsoft.AspNetCore.Authorization;
|
using Microsoft.AspNetCore.Authorization;
|
||||||
@ -12,15 +13,15 @@ namespace Bit.Api.Billing.Controllers;
|
|||||||
[Route("accounts/billing")]
|
[Route("accounts/billing")]
|
||||||
[Authorize("Application")]
|
[Authorize("Application")]
|
||||||
public class AccountsBillingController(
|
public class AccountsBillingController(
|
||||||
|
ICurrentContext currentContext,
|
||||||
IPaymentService paymentService,
|
IPaymentService paymentService,
|
||||||
IUserService userService,
|
|
||||||
IPaymentHistoryService paymentHistoryService) : Controller
|
IPaymentHistoryService paymentHistoryService) : Controller
|
||||||
{
|
{
|
||||||
[HttpGet("history")]
|
[HttpGet("history")]
|
||||||
[SelfHosted(NotSelfHostedOnly = true)]
|
[SelfHosted(NotSelfHostedOnly = true)]
|
||||||
public async Task<BillingHistoryResponseModel> GetBillingHistoryAsync()
|
public async Task<BillingHistoryResponseModel> GetBillingHistoryAsync()
|
||||||
{
|
{
|
||||||
var user = await userService.GetUserByPrincipalAsync(User);
|
var user = await currentContext.UserAsync.Value;
|
||||||
if (user == null)
|
if (user == null)
|
||||||
{
|
{
|
||||||
throw new UnauthorizedAccessException();
|
throw new UnauthorizedAccessException();
|
||||||
@ -34,7 +35,7 @@ public class AccountsBillingController(
|
|||||||
[SelfHosted(NotSelfHostedOnly = true)]
|
[SelfHosted(NotSelfHostedOnly = true)]
|
||||||
public async Task<BillingPaymentResponseModel> GetPaymentMethodAsync()
|
public async Task<BillingPaymentResponseModel> GetPaymentMethodAsync()
|
||||||
{
|
{
|
||||||
var user = await userService.GetUserByPrincipalAsync(User);
|
var user = await currentContext.UserAsync.Value;
|
||||||
if (user == null)
|
if (user == null)
|
||||||
{
|
{
|
||||||
throw new UnauthorizedAccessException();
|
throw new UnauthorizedAccessException();
|
||||||
@ -47,7 +48,7 @@ public class AccountsBillingController(
|
|||||||
[HttpGet("invoices")]
|
[HttpGet("invoices")]
|
||||||
public async Task<IResult> GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null)
|
public async Task<IResult> GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null)
|
||||||
{
|
{
|
||||||
var user = await userService.GetUserByPrincipalAsync(User);
|
var user = await currentContext.UserAsync.Value;
|
||||||
if (user == null)
|
if (user == null)
|
||||||
{
|
{
|
||||||
throw new UnauthorizedAccessException();
|
throw new UnauthorizedAccessException();
|
||||||
@ -65,7 +66,7 @@ public class AccountsBillingController(
|
|||||||
[HttpGet("transactions")]
|
[HttpGet("transactions")]
|
||||||
public async Task<IResult> GetTransactionsAsync([FromQuery] DateTime? startAfter = null)
|
public async Task<IResult> GetTransactionsAsync([FromQuery] DateTime? startAfter = null)
|
||||||
{
|
{
|
||||||
var user = await userService.GetUserByPrincipalAsync(User);
|
var user = await currentContext.UserAsync.Value;
|
||||||
if (user == null)
|
if (user == null)
|
||||||
{
|
{
|
||||||
throw new UnauthorizedAccessException();
|
throw new UnauthorizedAccessException();
|
||||||
@ -82,7 +83,7 @@ public class AccountsBillingController(
|
|||||||
[HttpPost("preview-invoice")]
|
[HttpPost("preview-invoice")]
|
||||||
public async Task<IResult> PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model)
|
public async Task<IResult> PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model)
|
||||||
{
|
{
|
||||||
var user = await userService.GetUserByPrincipalAsync(User);
|
var user = await currentContext.UserAsync.Value;
|
||||||
if (user == null)
|
if (user == null)
|
||||||
{
|
{
|
||||||
throw new UnauthorizedAccessException();
|
throw new UnauthorizedAccessException();
|
||||||
|
@ -44,31 +44,30 @@ public class UserStore :
|
|||||||
|
|
||||||
public async Task<User> FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken))
|
public async Task<User> FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken))
|
||||||
{
|
{
|
||||||
if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail)
|
var currentUser = await _currentContext.UserAsync.Value;
|
||||||
|
if (currentUser != null && currentUser.Email == normalizedEmail)
|
||||||
{
|
{
|
||||||
return _currentContext.User;
|
return currentUser;
|
||||||
}
|
}
|
||||||
|
|
||||||
_currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail);
|
return await _userRepository.GetByEmailAsync(normalizedEmail);
|
||||||
return _currentContext.User;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<User> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken))
|
public async Task<User> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken))
|
||||||
{
|
{
|
||||||
if (_currentContext?.User != null &&
|
var currentUser = await _currentContext.UserAsync.Value;
|
||||||
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
|
if (currentUser != null &&
|
||||||
|
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
|
||||||
{
|
{
|
||||||
return _currentContext.User;
|
return currentUser;
|
||||||
}
|
}
|
||||||
|
|
||||||
Guid userIdGuid;
|
if (!Guid.TryParse(userId, out var userIdGuid))
|
||||||
if (!Guid.TryParse(userId, out userIdGuid))
|
|
||||||
{
|
{
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid);
|
return await _userRepository.GetByIdAsync(userIdGuid);
|
||||||
return _currentContext.User;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<User> FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken))
|
public async Task<User> FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken))
|
||||||
|
@ -19,6 +19,8 @@ public class CurrentContext : ICurrentContext
|
|||||||
{
|
{
|
||||||
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
|
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
|
||||||
private readonly IProviderUserRepository _providerUserRepository;
|
private readonly IProviderUserRepository _providerUserRepository;
|
||||||
|
private readonly IUserRepository _userRepository;
|
||||||
|
|
||||||
private bool _builtHttpContext;
|
private bool _builtHttpContext;
|
||||||
private bool _builtClaimsPrincipal;
|
private bool _builtClaimsPrincipal;
|
||||||
private IEnumerable<ProviderOrganizationProviderDetails> _providerOrganizationProviderDetails;
|
private IEnumerable<ProviderOrganizationProviderDetails> _providerOrganizationProviderDetails;
|
||||||
@ -26,7 +28,7 @@ public class CurrentContext : ICurrentContext
|
|||||||
|
|
||||||
public virtual HttpContext HttpContext { get; set; }
|
public virtual HttpContext HttpContext { get; set; }
|
||||||
public virtual Guid? UserId { get; set; }
|
public virtual Guid? UserId { get; set; }
|
||||||
public virtual User User { get; set; }
|
public virtual Lazy<Task<User>> UserAsync { get; private set; }
|
||||||
public virtual string DeviceIdentifier { get; set; }
|
public virtual string DeviceIdentifier { get; set; }
|
||||||
public virtual DeviceType? DeviceType { get; set; }
|
public virtual DeviceType? DeviceType { get; set; }
|
||||||
public virtual string IpAddress { get; set; }
|
public virtual string IpAddress { get; set; }
|
||||||
@ -47,10 +49,12 @@ public class CurrentContext : ICurrentContext
|
|||||||
|
|
||||||
public CurrentContext(
|
public CurrentContext(
|
||||||
IProviderOrganizationRepository providerOrganizationRepository,
|
IProviderOrganizationRepository providerOrganizationRepository,
|
||||||
IProviderUserRepository providerUserRepository)
|
IProviderUserRepository providerUserRepository,
|
||||||
|
IUserRepository userRepository)
|
||||||
{
|
{
|
||||||
_providerOrganizationRepository = providerOrganizationRepository;
|
_providerOrganizationRepository = providerOrganizationRepository;
|
||||||
_providerUserRepository = providerUserRepository;
|
_providerUserRepository = providerUserRepository;
|
||||||
|
_userRepository = userRepository;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings)
|
public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings)
|
||||||
@ -138,6 +142,7 @@ public class CurrentContext : ICurrentContext
|
|||||||
if (Guid.TryParse(subject, out var subIdGuid))
|
if (Guid.TryParse(subject, out var subIdGuid))
|
||||||
{
|
{
|
||||||
UserId = subIdGuid;
|
UserId = subIdGuid;
|
||||||
|
UserAsync = new Lazy<Task<User>>(() => _userRepository.GetByIdAsync(UserId.Value));
|
||||||
}
|
}
|
||||||
|
|
||||||
ClientId = GetClaimValue(claimsDict, "client_id");
|
ClientId = GetClaimValue(claimsDict, "client_id");
|
||||||
|
@ -16,7 +16,7 @@ public interface ICurrentContext
|
|||||||
{
|
{
|
||||||
HttpContext HttpContext { get; set; }
|
HttpContext HttpContext { get; set; }
|
||||||
Guid? UserId { get; set; }
|
Guid? UserId { get; set; }
|
||||||
User User { get; set; }
|
Lazy<Task<User>> UserAsync { get; }
|
||||||
string DeviceIdentifier { get; set; }
|
string DeviceIdentifier { get; set; }
|
||||||
DeviceType? DeviceType { get; set; }
|
DeviceType? DeviceType { get; set; }
|
||||||
string IpAddress { get; set; }
|
string IpAddress { get; set; }
|
||||||
|
@ -173,10 +173,11 @@ public class UserService : UserManager<User>, IUserService, IDisposable
|
|||||||
|
|
||||||
public async Task<User> GetUserByIdAsync(string userId)
|
public async Task<User> GetUserByIdAsync(string userId)
|
||||||
{
|
{
|
||||||
if (_currentContext?.User != null &&
|
var currentUser = await _currentContext.UserAsync.Value;
|
||||||
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
|
if (currentUser != null &&
|
||||||
|
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
|
||||||
{
|
{
|
||||||
return _currentContext.User;
|
return currentUser;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!Guid.TryParse(userId, out var userIdGuid))
|
if (!Guid.TryParse(userId, out var userIdGuid))
|
||||||
@ -184,19 +185,18 @@ public class UserService : UserManager<User>, IUserService, IDisposable
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid);
|
return await _userRepository.GetByIdAsync(userIdGuid);
|
||||||
return _currentContext.User;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<User> GetUserByIdAsync(Guid userId)
|
public async Task<User> GetUserByIdAsync(Guid userId)
|
||||||
{
|
{
|
||||||
if (_currentContext?.User != null && _currentContext.User.Id == userId)
|
var currentUser = await _currentContext.UserAsync.Value;
|
||||||
|
if (currentUser != null && currentUser.Id == userId)
|
||||||
{
|
{
|
||||||
return _currentContext.User;
|
return currentUser;
|
||||||
}
|
}
|
||||||
|
|
||||||
_currentContext.User = await _userRepository.GetByIdAsync(userId);
|
return await _userRepository.GetByIdAsync(userId);
|
||||||
return _currentContext.User;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<User> GetUserByPrincipalAsync(ClaimsPrincipal principal)
|
public async Task<User> GetUserByPrincipalAsync(ClaimsPrincipal principal)
|
||||||
|
@ -20,7 +20,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
|
|||||||
|
|
||||||
public override async Task OnConnectedAsync()
|
public override async Task OnConnectedAsync()
|
||||||
{
|
{
|
||||||
var currentContext = new CurrentContext(null, null);
|
var currentContext = new CurrentContext(null, null, null);
|
||||||
await currentContext.BuildAsync(Context.User, _globalSettings);
|
await currentContext.BuildAsync(Context.User, _globalSettings);
|
||||||
|
|
||||||
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
|
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
|
||||||
@ -57,7 +57,7 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
|
|||||||
|
|
||||||
public override async Task OnDisconnectedAsync(Exception exception)
|
public override async Task OnDisconnectedAsync(Exception exception)
|
||||||
{
|
{
|
||||||
var currentContext = new CurrentContext(null, null);
|
var currentContext = new CurrentContext(null, null, null);
|
||||||
await currentContext.BuildAsync(Context.User, _globalSettings);
|
await currentContext.BuildAsync(Context.User, _globalSettings);
|
||||||
|
|
||||||
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
|
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user