1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-30 01:02:21 -05:00
Rui Tomé 8d69bb0aaa
[AC-1698] Check if a user has 2FA enabled more efficiently (#4524)
* feat: Add stored procedure for reading organization user details with premium access by organization ID

The code changes include:
- Addition of a new stored procedure [dbo].[OrganizationUserUserDetailsWithPremiumAccess_ReadByOrganizationId] to read organization user details with premium access by organization ID
- Modification of the IUserService interface to include an optional parameter for checking two-factor authentication with premium access
- Modification of the UserService class to handle the new optional parameter in the TwoFactorIsEnabledAsync method
- Addition of a new method GetManyDetailsWithPremiumAccessByOrganizationAsync in the IOrganizationUserRepository interface to retrieve organization user details with premium access by organization ID
- Addition of a new view [dbo].[OrganizationUserUserDetailsWithPremiumAccessView] to retrieve organization user details with premium access

* Add IUserRepository.SearchDetailsAsync that includes the field HasPremiumAccess

* Check the feature flag on Admin.UsersController to see if the optimization runs

* Modify PolicyService to run query optimization if the feature flag is enabled

* Refactor the parameter check on UserService.TwoFactorIsEnabledAsync

* Run query optimization on public MembersController if feature flag is enabled

* Restore refactor

* Reverted change used for development

* Add unit tests for OrganizationService.RestoreUser

* Separate new CheckPoliciesBeforeRestoreAsync optimization into new method

* Add more unit tests

* Apply refactor to bulk restore

* Add GetManyDetailsAsync method to IUserRepository. Add ConfirmUsersAsync_vNext method to IOrganizationService

* Add unit tests for ConfirmUser_vNext

* Refactor the optimization to use the new TwoFactorIsEnabledAsync method instead of changing the existing one

* Removed unused sql scripts and added migration script

* Remove unnecessary view

* chore: Remove unused SearchDetailsAsync method from IUserRepository and UserRepository

* refactor: Use UserDetails constructor in UserRepository

* Add summary to IUserRepository.GetManyDetailsAsync

* Add summary descriptions to IUserService.TwoFactorIsEnabledAsync

* Remove obsolete annotation from IUserRepository.UpdateUserKeyAndEncryptedDataAsync

* refactor: Rename UserDetails to UserWithCalculatedPremium across the codebase

* Extract IUserService.TwoFactorIsEnabledAsync into a new TwoFactorIsEnabledQuery class

* Add unit tests for TwoFactorIsEnabledQuery

* Update TwoFactorIsEnabledQueryTests to include additional provider types

* Refactor TwoFactorIsEnabledQuery

* Refactor TwoFactorIsEnabledQuery and update tests

* refactor: Update TwoFactorIsEnabledQueryTests to include test for null TwoFactorProviders

* refactor: Improve TwoFactorIsEnabledQuery and update tests

* refactor: Improve TwoFactorIsEnabledQuery and update tests

* Remove empty <returns> from summary

* Update User_ReadByIdsWithCalculatedPremium stored procedure to accept JSON array of IDs
2024-08-08 15:43:45 +01:00

267 lines
10 KiB
C#

using AutoMapper;
using Bit.Core.Auth.UserFeatures.UserKey;
using Bit.Core.Repositories;
using Bit.Infrastructure.EntityFramework.Models;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using DataModel = Bit.Core.Models.Data;
#nullable enable
namespace Bit.Infrastructure.EntityFramework.Repositories;
public class UserRepository : Repository<Core.Entities.User, User, Guid>, IUserRepository
{
public UserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper)
: base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Users)
{ }
public async Task<Core.Entities.User?> GetByEmailAsync(string email)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var entity = await GetDbSet(dbContext).FirstOrDefaultAsync(e => e.Email == email);
return Mapper.Map<Core.Entities.User>(entity);
}
}
public async Task<IEnumerable<Core.Entities.User>> GetManyByEmailsAsync(IEnumerable<string> emails)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var users = await GetDbSet(dbContext)
.Where(u => emails.Contains(u.Email))
.ToListAsync();
return Mapper.Map<List<Core.Entities.User>>(users);
}
}
public async Task<DataModel.UserKdfInformation?> GetKdfInformationByEmailAsync(string email)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
return await GetDbSet(dbContext).Where(e => e.Email == email)
.Select(e => new DataModel.UserKdfInformation
{
Kdf = e.Kdf,
KdfIterations = e.KdfIterations,
KdfMemory = e.KdfMemory,
KdfParallelism = e.KdfParallelism
}).SingleOrDefaultAsync();
}
}
public async Task<ICollection<Core.Entities.User>> SearchAsync(string email, int skip, int take)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
List<User> users;
if (dbContext.Database.IsNpgsql())
{
users = await GetDbSet(dbContext)
.Where(e => e.Email == null ||
EF.Functions.ILike(EF.Functions.Collate(e.Email, "default"), $"{email}%"))
.OrderBy(e => e.Email)
.Skip(skip).Take(take)
.ToListAsync();
}
else
{
users = await GetDbSet(dbContext)
.Where(e => email == null || e.Email.StartsWith(email))
.OrderBy(e => e.Email)
.Skip(skip).Take(take)
.ToListAsync();
}
return Mapper.Map<List<Core.Entities.User>>(users);
}
}
public async Task<ICollection<Core.Entities.User>> GetManyByPremiumAsync(bool premium)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var users = await GetDbSet(dbContext).Where(e => e.Premium == premium).ToListAsync();
return Mapper.Map<List<Core.Entities.User>>(users);
}
}
public async Task<string?> GetPublicKeyAsync(Guid id)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.PublicKey).SingleOrDefaultAsync();
}
}
public async Task<DateTime> GetAccountRevisionDateAsync(Guid id)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.AccountRevisionDate)
.SingleOrDefaultAsync();
}
}
public async Task UpdateStorageAsync(Guid id)
{
await base.UserUpdateStorage(id);
}
public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var user = new User
{
Id = id,
RenewalReminderDate = renewalReminderDate,
};
var set = GetDbSet(dbContext);
set.Attach(user);
dbContext.Entry(user).Property(e => e.RenewalReminderDate).IsModified = true;
await dbContext.SaveChangesAsync();
}
}
public async Task<Core.Entities.User?> GetBySsoUserAsync(string externalId, Guid? organizationId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var ssoUser = await dbContext.SsoUsers.SingleOrDefaultAsync(e =>
e.OrganizationId == organizationId && e.ExternalId == externalId);
if (ssoUser == null)
{
return null;
}
var entity = await dbContext.Users.SingleOrDefaultAsync(e => e.Id == ssoUser.UserId);
return Mapper.Map<Core.Entities.User>(entity);
}
}
/// <inheritdoc />
public async Task UpdateUserKeyAndEncryptedDataAsync(Core.Entities.User user,
IEnumerable<UpdateEncryptedDataForKeyRotation> updateDataActions)
{
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
await using var transaction = await dbContext.Database.BeginTransactionAsync();
try
{
// Update user
var entity = await dbContext.Users.FindAsync(user.Id);
if (entity == null)
{
throw new ArgumentException("User not found", nameof(user));
}
entity.SecurityStamp = user.SecurityStamp;
entity.Key = user.Key;
entity.PrivateKey = user.PrivateKey;
entity.LastKeyRotationDate = user.LastKeyRotationDate;
entity.AccountRevisionDate = user.AccountRevisionDate;
entity.RevisionDate = user.RevisionDate;
await dbContext.SaveChangesAsync();
// Update re-encrypted data
foreach (var action in updateDataActions)
{
// connection and transaction aren't used in EF
await action();
}
await transaction.CommitAsync();
}
catch
{
await transaction.RollbackAsync();
throw;
}
}
public async Task<IEnumerable<Core.Entities.User>> GetManyAsync(IEnumerable<Guid> ids)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var users = dbContext.Users.Where(x => ids.Contains(x.Id));
return await users.ToListAsync();
}
}
public async Task<IEnumerable<DataModel.UserWithCalculatedPremium>> GetManyWithCalculatedPremiumAsync(IEnumerable<Guid> ids)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var users = dbContext.Users.Where(x => ids.Contains(x.Id));
return await users.Select(e => new DataModel.UserWithCalculatedPremium(e)
{
HasPremiumAccess = e.Premium || dbContext.OrganizationUsers
.Any(ou => ou.UserId == e.Id &&
dbContext.Organizations
.Any(o => o.Id == ou.OrganizationId &&
o.UsersGetPremium == true &&
o.Enabled == true))
}).ToListAsync();
}
}
public override async Task DeleteAsync(Core.Entities.User user)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var transaction = await dbContext.Database.BeginTransactionAsync();
dbContext.WebAuthnCredentials.RemoveRange(dbContext.WebAuthnCredentials.Where(w => w.UserId == user.Id));
dbContext.Ciphers.RemoveRange(dbContext.Ciphers.Where(c => c.UserId == user.Id));
dbContext.Folders.RemoveRange(dbContext.Folders.Where(f => f.UserId == user.Id));
dbContext.AuthRequests.RemoveRange(dbContext.AuthRequests.Where(s => s.UserId == user.Id));
dbContext.Devices.RemoveRange(dbContext.Devices.Where(d => d.UserId == user.Id));
var collectionUsers = from cu in dbContext.CollectionUsers
join ou in dbContext.OrganizationUsers on cu.OrganizationUserId equals ou.Id
where ou.UserId == user.Id
select cu;
dbContext.CollectionUsers.RemoveRange(collectionUsers);
var groupUsers = from gu in dbContext.GroupUsers
join ou in dbContext.OrganizationUsers on gu.OrganizationUserId equals ou.Id
where ou.UserId == user.Id
select gu;
dbContext.GroupUsers.RemoveRange(groupUsers);
dbContext.UserProjectAccessPolicy.RemoveRange(
dbContext.UserProjectAccessPolicy.Where(ap => ap.OrganizationUser.UserId == user.Id));
dbContext.UserServiceAccountAccessPolicy.RemoveRange(
dbContext.UserServiceAccountAccessPolicy.Where(ap => ap.OrganizationUser.UserId == user.Id));
dbContext.OrganizationUsers.RemoveRange(dbContext.OrganizationUsers.Where(ou => ou.UserId == user.Id));
dbContext.ProviderUsers.RemoveRange(dbContext.ProviderUsers.Where(pu => pu.UserId == user.Id));
dbContext.SsoUsers.RemoveRange(dbContext.SsoUsers.Where(su => su.UserId == user.Id));
dbContext.EmergencyAccesses.RemoveRange(
dbContext.EmergencyAccesses.Where(ea => ea.GrantorId == user.Id || ea.GranteeId == user.Id));
dbContext.Sends.RemoveRange(dbContext.Sends.Where(s => s.UserId == user.Id));
var mappedUser = Mapper.Map<User>(user);
dbContext.Users.Remove(mappedUser);
await transaction.CommitAsync();
await dbContext.SaveChangesAsync();
}
}
}