1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-02 16:42:50 -05:00

[PS-1928] Add BumpAccountRevisionDate methods (#2458)

* Move RevisionDate Bumps to Extension Class

* Add Tests against live databases

* Run Formatting

* Fix Typo

* Fix Test Solution Typo

* Await ReplaceAsync
This commit is contained in:
Justin Baur
2022-12-02 14:24:30 -05:00
committed by GitHub
parent 41db511872
commit efe91fd0d8
25 changed files with 3788 additions and 309 deletions

View File

@ -1,13 +1,10 @@
using System.Text.Json;
using AutoMapper;
using Bit.Core.Enums;
using Bit.Core.Enums.Provider;
using Bit.Infrastructure.EntityFramework.Models;
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
using LinqToDB.Data;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Cipher = Bit.Core.Entities.Cipher;
using User = Bit.Core.Entities.User;
namespace Bit.Infrastructure.EntityFramework.Repositories;
@ -51,68 +48,6 @@ public abstract class BaseEntityFrameworkRepository
}
}
protected async Task UserBumpAccountRevisionDateByCipherId(Cipher cipher)
{
var list = new List<Cipher> { cipher };
await UserBumpAccountRevisionDateByCipherId(list);
}
protected async Task UserBumpAccountRevisionDateByCipherId(IEnumerable<Cipher> ciphers)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
foreach (var cipher in ciphers)
{
var dbContext = GetDatabaseContext(scope);
var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher);
var users = query.Run(dbContext);
await users.ForEachAsync(e =>
{
dbContext.Attach(e);
e.RevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
}
protected async Task UserBumpAccountRevisionDateByOrganizationId(Guid organizationId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = new UserBumpAccountRevisionDateByOrganizationIdQuery(organizationId);
var users = query.Run(dbContext);
await users.ForEachAsync(e =>
{
dbContext.Attach(e);
e.RevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
protected async Task UserBumpAccountRevisionDate(Guid userId)
{
await UserBumpManyAccountRevisionDates(new[] { userId });
}
protected async Task UserBumpManyAccountRevisionDates(ICollection<Guid> userIds)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var users = dbContext.Users.Where(u => userIds.Contains(u.Id));
await users.ForEachAsync(u =>
{
dbContext.Attach(u);
u.RevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
protected async Task OrganizationUpdateStorage(Guid organizationId)
{
using (var scope = ServiceScopeFactory.CreateScope())
@ -197,81 +132,4 @@ public abstract class BaseEntityFrameworkRepository
await dbContext.SaveChangesAsync();
}
}
protected async Task UserBumpAccountRevisionDateByCollectionId(Guid collectionId, Guid organizationId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = from u in dbContext.Users
join ou in dbContext.OrganizationUsers
on u.Id equals ou.UserId
join cu in dbContext.CollectionUsers
on ou.Id equals cu.OrganizationUserId into cu_g
from cu in cu_g.DefaultIfEmpty()
where !ou.AccessAll && cu.CollectionId.Equals(collectionId)
join gu in dbContext.GroupUsers
on ou.Id equals gu.OrganizationUserId into gu_g
from gu in gu_g.DefaultIfEmpty()
where cu.CollectionId == default(Guid) && !ou.AccessAll
join g in dbContext.Groups
on gu.GroupId equals g.Id into g_g
from g in g_g.DefaultIfEmpty()
join cg in dbContext.CollectionGroups
on gu.GroupId equals cg.GroupId into cg_g
from cg in cg_g.DefaultIfEmpty()
where !g.AccessAll && cg.CollectionId == collectionId &&
(ou.OrganizationId == organizationId && ou.Status == OrganizationUserStatusType.Confirmed &&
(cu.CollectionId != default(Guid) || cg.CollectionId != default(Guid) || ou.AccessAll || g.AccessAll))
select new { u, ou, cu, gu, g, cg };
var users = query.Select(x => x.u);
await users.ForEachAsync(u =>
{
dbContext.Attach(u);
u.RevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
protected async Task UserBumpAccountRevisionDateByOrganizationUserId(Guid organizationUserId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = from u in dbContext.Users
join ou in dbContext.OrganizationUsers
on u.Id equals ou.UserId
where ou.Id.Equals(organizationUserId) && ou.Status.Equals(OrganizationUserStatusType.Confirmed)
select new { u, ou };
var users = query.Select(x => x.u);
await users.ForEachAsync(u =>
{
dbContext.Attach(u);
u.AccountRevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
protected async Task UserBumpAccountRevisionDateByProviderUserIds(ICollection<Guid> providerUserIds)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = from pu in dbContext.ProviderUsers
join u in dbContext.Users
on pu.UserId equals u.Id
where pu.Status.Equals(ProviderUserStatusType.Confirmed) &&
providerUserIds.Contains(pu.Id)
select new { pu, u };
var users = query.Select(x => x.u);
await users.ForEachAsync(u =>
{
dbContext.Attach(u);
u.AccountRevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
}

View File

@ -29,40 +29,59 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
var dbContext = GetDatabaseContext(scope);
if (cipher.OrganizationId.HasValue)
{
await UserBumpAccountRevisionDateByCipherId(cipher);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId);
}
else if (cipher.UserId.HasValue)
{
await UserBumpAccountRevisionDate(cipher.UserId.Value);
await dbContext.UserBumpAccountRevisionDateAsync(cipher.UserId.Value);
}
await dbContext.SaveChangesAsync();
}
return cipher;
}
public IQueryable<User> GetBumpedAccountsByCipherId(Core.Entities.Cipher cipher)
public override async Task DeleteAsync(Core.Entities.Cipher cipher)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher);
return query.Run(dbContext);
var cipherInfo = await dbContext.Ciphers
.Where(c => c.Id == cipher.Id)
.Select(c => new { c.UserId, c.OrganizationId, HasAttachments = c.Attachments != null })
.FirstOrDefaultAsync();
await base.DeleteAsync(cipher);
if (cipherInfo?.OrganizationId != null)
{
if (cipherInfo.HasAttachments == true)
{
await OrganizationUpdateStorage(cipherInfo.OrganizationId.Value);
}
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipherInfo.OrganizationId);
}
else if (cipherInfo?.UserId != null)
{
if (cipherInfo.HasAttachments)
{
await UserUpdateStorage(cipherInfo.UserId.Value);
}
await dbContext.UserBumpAccountRevisionDateAsync(cipherInfo.UserId.Value);
}
await dbContext.SaveChangesAsync();
}
}
public async Task CreateAsync(Core.Entities.Cipher cipher, IEnumerable<Guid> collectionIds)
{
cipher = await base.CreateAsync(cipher);
await UpdateCollections(cipher, collectionIds);
}
private async Task UpdateCollections(Core.Entities.Cipher cipher, IEnumerable<Guid> collectionIds)
{
cipher = await CreateAsync(cipher);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var cipherEntity = await dbContext.Ciphers.FindAsync(cipher.Id);
var query = new CipherUpdateCollectionsQuery(cipherEntity, collectionIds).Run(dbContext);
await dbContext.AddRangeAsync(query);
await UpdateCollectionsAsync(dbContext, cipher.Id,
cipher.UserId, cipher.OrganizationId, collectionIds);
await dbContext.SaveChangesAsync();
}
}
@ -88,16 +107,22 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
null;
var entity = Mapper.Map<Cipher>((Core.Entities.Cipher)cipher);
await dbContext.AddAsync(entity);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId.GetValueOrDefault());
await dbContext.SaveChangesAsync();
}
await UserBumpAccountRevisionDateByCipherId(cipher);
return cipher;
}
public async Task CreateAsync(CipherDetails cipher, IEnumerable<Guid> collectionIds)
{
cipher = await CreateAsyncReturnCipher(cipher);
await UpdateCollections(cipher, collectionIds);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await UpdateCollectionsAsync(dbContext, cipher.Id,
cipher.UserId, cipher.OrganizationId, collectionIds);
await dbContext.SaveChangesAsync();
}
}
public async Task CreateAsync(IEnumerable<Core.Entities.Cipher> ciphers, IEnumerable<Core.Entities.Folder> folders)
@ -114,7 +139,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities);
var cipherEntities = Mapper.Map<List<Cipher>>(ciphers);
await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities);
await UserBumpAccountRevisionDateByCipherId(ciphers);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(ciphers);
await dbContext.SaveChangesAsync();
}
}
@ -140,7 +166,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionCipherEntities);
}
}
await UserBumpAccountRevisionDateByOrganizationId(ciphers.First().OrganizationId.Value);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(ciphers.First().OrganizationId.Value);
await dbContext.SaveChangesAsync();
}
}
@ -163,13 +190,14 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
if (cipher.OrganizationId.HasValue)
{
await OrganizationUpdateStorage(cipher.OrganizationId.Value);
await UserBumpAccountRevisionDateByCipherId(cipher);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId.Value);
}
else if (cipher.UserId.HasValue)
{
await UserUpdateStorage(cipher.UserId.Value);
await UserBumpAccountRevisionDate(cipher.UserId.Value);
await dbContext.UserBumpAccountRevisionDateAsync(cipher.UserId.Value);
}
await dbContext.SaveChangesAsync();
}
}
@ -184,9 +212,10 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
select c;
dbContext.RemoveRange(ciphers);
await dbContext.SaveChangesAsync();
await OrganizationUpdateStorage(organizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
}
await OrganizationUpdateStorage(organizationId);
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
}
public async Task DeleteByOrganizationIdAsync(Guid organizationId)
@ -207,10 +236,10 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
select c;
dbContext.RemoveRange(ciphers);
await OrganizationUpdateStorage(organizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
}
await OrganizationUpdateStorage(organizationId);
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
}
public async Task DeleteByUserIdAsync(Guid userId)
@ -228,7 +257,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
dbContext.RemoveRange(folders);
await dbContext.SaveChangesAsync();
await UserUpdateStorage(userId);
await UserBumpAccountRevisionDate(userId);
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
}
}
@ -364,8 +394,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
dbContext.Attach(cipher);
cipher.Folders = JsonConvert.SerializeObject(foldersJson);
});
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDate(userId);
}
}
@ -427,26 +457,100 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
}
var mappedEntity = Mapper.Map<Cipher>((Core.Entities.Cipher)cipher);
dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity);
await UserBumpAccountRevisionDateByCipherId(cipher);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId.GetValueOrDefault());
await dbContext.SaveChangesAsync();
}
}
}
public async Task<bool> ReplaceAsync(Core.Entities.Cipher obj, IEnumerable<Guid> collectionIds)
private static async Task<int> UpdateCollectionsAsync(DatabaseContext context, Guid id, Guid? userId, Guid? organizationId, IEnumerable<Guid> collectionIds)
{
if (!organizationId.HasValue || !collectionIds.Any())
{
return -1;
}
IQueryable<Guid> availableCollectionsQuery;
if (!userId.HasValue)
{
availableCollectionsQuery = context.Collections
.Where(c => c.OrganizationId == organizationId.Value)
.Select(c => c.Id);
}
else
{
availableCollectionsQuery = from c in context.Collections
join o in context.Organizations
on c.OrganizationId equals o.Id
join ou in context.OrganizationUsers
on new { OrganizationId = o.Id, UserId = (Guid?)userId.Value } equals
new { ou.OrganizationId, ou.UserId }
join cu in context.CollectionUsers
on new { ou.AccessAll, CollectionId = c.Id, OrganizationUserId = ou.Id } equals
new { AccessAll = false, cu.CollectionId, cu.OrganizationUserId } into cu_g
from cu in cu_g.DefaultIfEmpty()
join gu in context.GroupUsers
on new { CollectionId = (Guid?)cu.CollectionId, ou.AccessAll, OrganizationUserId = ou.Id } equals
new { CollectionId = (Guid?)null, AccessAll = false, gu.OrganizationUserId } into gu_g
from gu in gu_g.DefaultIfEmpty()
join g in context.Groups
on gu.GroupId equals g.Id into g_g
from g in g_g.DefaultIfEmpty()
join cg in context.CollectionGroups
on new { g.AccessAll, CollectionId = c.Id, gu.GroupId } equals
new { AccessAll = false, cg.CollectionId, cg.GroupId }
where o.Id == organizationId &&
o.Enabled &&
ou.Status == OrganizationUserStatusType.Confirmed &&
(ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)
select c.Id;
}
var availableCollections = await availableCollectionsQuery.ToListAsync();
if (!availableCollections.Any())
{
return -1;
}
var collectionCiphers = collectionIds
.Where(collectionId => availableCollections.Contains(collectionId))
.Select(collectionId => new CollectionCipher
{
CollectionId = collectionId,
CipherId = id,
});
context.CollectionCiphers.AddRange(collectionCiphers);
return 0;
}
public async Task<bool> ReplaceAsync(Core.Entities.Cipher cipher, IEnumerable<Guid> collectionIds)
{
await UpdateCollections(obj, collectionIds);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var cipher = await dbContext.Ciphers.FindAsync(obj.Id);
cipher.UserId = null;
cipher.OrganizationId = obj.OrganizationId;
cipher.Data = obj.Data;
cipher.Attachments = obj.Attachments;
cipher.RevisionDate = obj.RevisionDate;
cipher.DeletedDate = obj.DeletedDate;
await dbContext.SaveChangesAsync();
var transaction = await dbContext.Database.BeginTransactionAsync();
var successes = await UpdateCollectionsAsync(
dbContext, cipher.Id, cipher.UserId,
cipher.OrganizationId, collectionIds);
if (successes < 0)
{
await transaction.CommitAsync();
return false;
}
var trackedCipher = await dbContext.Ciphers.FindAsync(cipher.Id);
trackedCipher.UserId = null;
trackedCipher.OrganizationId = cipher.OrganizationId;
trackedCipher.Data = cipher.Data;
trackedCipher.Attachments = cipher.Attachments;
trackedCipher.RevisionDate = cipher.RevisionDate;
trackedCipher.DeletedDate = cipher.DeletedDate;
await transaction.CommitAsync();
if (!string.IsNullOrWhiteSpace(cipher.Attachments))
{
@ -460,7 +564,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
}
}
await UserBumpAccountRevisionDateByCipherId(cipher);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId.GetValueOrDefault());
await dbContext.SaveChangesAsync();
return true;
}
}
@ -522,13 +627,13 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
foreach (var orgId in orgIds)
{
await OrganizationUpdateStorage(orgId.Value);
await UserBumpAccountRevisionDateByOrganizationId(orgId.Value);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(orgId.Value);
}
if (query.Any(c => c.UserId.HasValue && !string.IsNullOrWhiteSpace(c.Attachments)))
{
await UserUpdateStorage(userId);
}
await UserBumpAccountRevisionDate(userId);
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
return utcNow;
}
@ -547,9 +652,9 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
cipher.DeletedDate = utcNow;
cipher.RevisionDate = utcNow;
});
await dbContext.SaveChangesAsync();
await OrganizationUpdateStorage(organizationId);
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
}
}
@ -570,13 +675,14 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
if (attachment.OrganizationId.HasValue)
{
await OrganizationUpdateStorage(cipher.OrganizationId.Value);
await UserBumpAccountRevisionDateByCipherId(new List<Core.Entities.Cipher> { cipher });
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId);
}
else if (attachment.UserId.HasValue)
{
await UserUpdateStorage(attachment.UserId.Value);
await UserBumpAccountRevisionDate(attachment.UserId.Value);
await dbContext.UserBumpAccountRevisionDateAsync(attachment.UserId.Value);
}
await dbContext.SaveChangesAsync();
}
}
@ -591,7 +697,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
var dbContext = GetDatabaseContext(scope);
var entities = Mapper.Map<List<Cipher>>(ciphers);
await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, entities);
await UserBumpAccountRevisionDate(userId);
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
}
}
@ -626,8 +733,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
favoritesJson.Remove(userId.ToString());
}
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDate(userId);
}
}

View File

@ -25,7 +25,8 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec
var organizationId = (await dbContext.Ciphers.FirstOrDefaultAsync(c => c.Id.Equals(obj.CipherId))).OrganizationId;
if (organizationId.HasValue)
{
await UserBumpAccountRevisionDateByCollectionId(obj.CollectionId, organizationId.Value);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(obj.CollectionId, organizationId.Value);
await dbContext.SaveChangesAsync();
}
return obj;
}
@ -132,12 +133,12 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec
});
await dbContext.AddRangeAsync(insert);
dbContext.RemoveRange(delete);
await dbContext.SaveChangesAsync();
if (organizationId.HasValue)
{
await UserBumpAccountRevisionDateByOrganizationId(organizationId.Value);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId.Value);
}
await dbContext.SaveChangesAsync();
}
}
@ -182,8 +183,8 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec
});
await dbContext.AddRangeAsync(insert);
dbContext.RemoveRange(delete);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
}
}
@ -231,7 +232,8 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec
CipherId = cipherId,
};
await dbContext.AddRangeAsync(insertData);
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
}
}
}

View File

@ -14,16 +14,43 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
: base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Collections)
{ }
public override async Task<Core.Entities.Collection> CreateAsync(Core.Entities.Collection obj)
public override async Task<Core.Entities.Collection> CreateAsync(Core.Entities.Collection collection)
{
await base.CreateAsync(obj);
await UserBumpAccountRevisionDateByCollectionId(obj.Id, obj.OrganizationId);
return obj;
await base.CreateAsync(collection);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
}
return collection;
}
public override async Task DeleteAsync(Core.Entities.Collection collection)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
}
await base.DeleteAsync(collection);
}
public override async Task UpsertAsync(Core.Entities.Collection collection)
{
await base.UpsertAsync(collection);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
}
}
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<SelectionReadOnly> groups)
{
await base.CreateAsync(obj);
await CreateAsync(obj);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
@ -40,8 +67,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
HidePasswords = g.HidePasswords,
});
await dbContext.AddRangeAsync(collectionGroups);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(obj.OrganizationId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId);
}
}
@ -55,8 +82,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
cu.OrganizationUserId == organizationUserId
select cu;
dbContext.RemoveRange(await query.ToListAsync());
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByOrganizationUserId(organizationUserId);
}
}
@ -167,7 +194,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
public async Task ReplaceAsync(Core.Entities.Collection collection, IEnumerable<SelectionReadOnly> groups)
{
await base.ReplaceAsync(collection);
await UpsertAsync(collection);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
@ -228,8 +255,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
await dbContext.AddRangeAsync(insert);
dbContext.UpdateRange(update);
dbContext.RemoveRange(delete);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByCollectionId(collection.Id, collection.OrganizationId);
}
}
@ -273,7 +300,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
// Remove all existing ones that are no longer requested
var requestedUserIds = requestedUsers.Select(u => u.Id);
dbContext.CollectionUsers.RemoveRange(existingCollectionUsers.Where(cu => !requestedUserIds.Contains(cu.OrganizationUserId)));
await UserBumpAccountRevisionDateByCollectionId(id, organizationId);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(id, organizationId);
await dbContext.SaveChangesAsync();
}
}

View File

@ -0,0 +1,146 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Enums.Provider;
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
using Microsoft.EntityFrameworkCore;
namespace Bit.Infrastructure.EntityFramework.Repositories;
public static class DatabaseContextExtensions
{
public static async Task UserBumpAccountRevisionDateAsync(this DatabaseContext context, Guid userId)
{
var user = await context.Users.FindAsync(userId);
user.AccountRevisionDate = DateTime.UtcNow;
}
public static async Task UserBumpManyAccountRevisionDatesAsync(this DatabaseContext context, ICollection<Guid> userIds)
{
var users = context.Users.Where(u => userIds.Contains(u.Id));
var currentTime = DateTime.UtcNow;
await users.ForEachAsync(u =>
{
context.Attach(u);
u.AccountRevisionDate = currentTime;
});
}
public static async Task UserBumpAccountRevisionDateByOrganizationIdAsync(this DatabaseContext context, Guid organizationId)
{
var users = await (from u in context.Users
join ou in context.OrganizationUsers on u.Id equals ou.UserId
where ou.OrganizationId == organizationId && ou.Status == OrganizationUserStatusType.Confirmed
select u).ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByCipherIdAsync(this DatabaseContext context, Guid cipherId, Guid? organizationId)
{
var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipherId, organizationId);
var users = await query.Run(context).ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByCipherIdAsync(this DatabaseContext context, IEnumerable<Cipher> ciphers)
{
foreach (var cipher in ciphers)
{
await context.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId);
}
}
public static async Task UserBumpAccountRevisionDateByCollectionIdAsync(this DatabaseContext context, Guid collectionId, Guid organizationId)
{
var query = from u in context.Users
join ou in context.OrganizationUsers
on u.Id equals ou.UserId
join cu in context.CollectionUsers
on new { ou.AccessAll, OrganizationUserId = ou.Id, CollectionId = collectionId } equals
new { AccessAll = false, cu.OrganizationUserId, cu.CollectionId } into cu_g
from cu in cu_g.DefaultIfEmpty()
join gu in context.GroupUsers
on new { CollectionId = (Guid?)cu.CollectionId, ou.AccessAll, OrganizationUserId = ou.Id } equals
new { CollectionId = (Guid?)null, AccessAll = false, gu.OrganizationUserId } into gu_g
from gu in gu_g.DefaultIfEmpty()
join g in context.Groups
on gu.GroupId equals g.Id into g_g
from g in g_g.DefaultIfEmpty()
join cg in context.CollectionGroups
on new { g.AccessAll, gu.GroupId, CollectionId = collectionId } equals
new { AccessAll = false, cg.GroupId, cg.CollectionId } into cg_g
from cg in cg_g.DefaultIfEmpty()
where ou.OrganizationId == organizationId &&
ou.Status == OrganizationUserStatusType.Confirmed &&
cg.CollectionId != null &&
ou.AccessAll == true &&
g.AccessAll == true
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByOrganizationUserIdAsync(this DatabaseContext context, Guid organizationUserId)
{
var query = from u in context.Users
join ou in context.OrganizationUsers
on u.Id equals ou.UserId
where ou.Id == organizationUserId && ou.Status == OrganizationUserStatusType.Confirmed
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByOrganizationUserIdsAsync(this DatabaseContext context, IEnumerable<Guid> organizationUserIds)
{
foreach (var organizationUserId in organizationUserIds)
{
await context.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserId);
}
}
public static async Task UserBumpAccountRevisionDateByEmergencyAccessGranteeIdAsync(this DatabaseContext context, Guid emergencyAccessId)
{
var query = from u in context.Users
join ea in context.EmergencyAccesses on u.Id equals ea.GranteeId
where ea.Id == emergencyAccessId && ea.Status == EmergencyAccessStatusType.Confirmed
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByProviderIdAsync(this DatabaseContext context, Guid providerId)
{
var query = from u in context.Users
join pu in context.ProviderUsers on u.Id equals pu.UserId
where pu.ProviderId == providerId && pu.Status == ProviderUserStatusType.Confirmed
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByProviderUserIdAsync(this DatabaseContext context, Guid providerUserId)
{
var query = from u in context.Users
join pu in context.ProviderUsers on u.Id equals pu.UserId
where pu.ProviderId == providerUserId && pu.Status == ProviderUserStatusType.Confirmed
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
private static void UpdateUserRevisionDate(List<Models.User> users)
{
var time = DateTime.UtcNow;
foreach (var user in users)
{
user.AccountRevisionDate = time;
}
}
}

View File

@ -21,6 +21,17 @@ public class EmergencyAccessRepository : Repository<Core.Entities.EmergencyAcces
return await GetCountFromQuery(query);
}
public override async Task DeleteAsync(Core.Entities.EmergencyAccess emergencyAccess)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByEmergencyAccessGranteeIdAsync(emergencyAccess.Id);
await dbContext.SaveChangesAsync();
}
await base.DeleteAsync(emergencyAccess);
}
public async Task<EmergencyAccessDetails> GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@ -46,6 +46,7 @@ public class GroupRepository : Repository<Core.Entities.Group, Group, Guid>, IGr
gu.OrganizationUserId == organizationUserId
select gu;
dbContext.RemoveRange(await query.ToListAsync());
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserId);
await dbContext.SaveChangesAsync();
}
}
@ -134,7 +135,8 @@ public class GroupRepository : Repository<Core.Entities.Group, Group, Guid>, IGr
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(obj.OrganizationId);
await dbContext.SaveChangesAsync();
}
}
@ -161,7 +163,8 @@ public class GroupRepository : Repository<Core.Entities.Group, Group, Guid>, IGr
select gu;
dbContext.RemoveRange(delete);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByOrganizationId(orgId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(orgId);
await dbContext.SaveChangesAsync();
}
}
}

View File

@ -70,9 +70,39 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var orgUser = await dbContext.FindAsync<OrganizationUser>(organizationUserId);
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserId);
var orgUser = await dbContext.OrganizationUsers
.Where(ou => ou.Id == organizationUserId)
.FirstAsync();
dbContext.Remove(orgUser);
var organizationId = orgUser?.OrganizationId;
var userId = orgUser?.UserId;
if (orgUser?.OrganizationId != null && orgUser?.UserId != null)
{
var ssoUsers = dbContext.SsoUsers
.Where(su => su.UserId == userId && su.OrganizationId == organizationId);
dbContext.SsoUsers.RemoveRange(ssoUsers);
}
var collectionUsers = dbContext.CollectionUsers
.Where(cu => cu.OrganizationUserId == organizationUserId);
dbContext.CollectionUsers.RemoveRange(collectionUsers);
var groupUsers = dbContext.GroupUsers
.Where(gu => gu.OrganizationUserId == organizationUserId);
dbContext.GroupUsers.RemoveRange(groupUsers);
var orgSponsorships = await dbContext.OrganizationSponsorships
.Where(os => os.SponsoringOrganizationUserId == organizationUserId)
.ToListAsync();
foreach (var orgSponsorship in orgSponsorships)
{
orgSponsorship.ToDelete = true;
}
dbContext.OrganizationUsers.Remove(orgUser);
await dbContext.SaveChangesAsync();
}
}
@ -82,7 +112,9 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdsAsync(organizationUserIds);
var entities = await dbContext.OrganizationUsers
// TODO: Does this work?
.Where(ou => organizationUserIds.Contains(ou.Id))
.ToListAsync();
@ -309,9 +341,20 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
}
}
public async override Task ReplaceAsync(Core.Entities.OrganizationUser organizationUser)
{
await base.ReplaceAsync(organizationUser);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateAsync(organizationUser.UserId.GetValueOrDefault());
await dbContext.SaveChangesAsync();
}
}
public async Task ReplaceAsync(Core.Entities.OrganizationUser obj, IEnumerable<SelectionReadOnly> requestedCollections)
{
await base.ReplaceAsync(obj);
await ReplaceAsync(obj);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
@ -356,7 +399,7 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
var dbContext = GetDatabaseContext(scope);
dbContext.UpdateRange(organizationUsers);
await dbContext.SaveChangesAsync();
await UserBumpManyAccountRevisionDates(organizationUsers
await dbContext.UserBumpManyAccountRevisionDatesAsync(organizationUsers
.Where(ou => ou.UserId.HasValue)
.Select(ou => ou.UserId.Value).ToArray());
}
@ -400,7 +443,7 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
var delete = procedure.Delete.Run(dbContext);
var deleteData = await delete.ToListAsync();
dbContext.RemoveRange(deleteData);
await UserBumpAccountRevisionDateByOrganizationUserId(orgUserId);
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(orgUserId);
await dbContext.SaveChangesAsync();
}
}
@ -449,17 +492,15 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var orgUser = await GetDbSet(dbContext).FindAsync(id);
if (orgUser != null)
var orgUser = await dbContext.OrganizationUsers.FindAsync(id);
if (orgUser == null)
{
dbContext.Update(orgUser);
orgUser.Status = OrganizationUserStatusType.Revoked;
await dbContext.SaveChangesAsync();
if (orgUser.UserId.HasValue)
{
await UserBumpAccountRevisionDate(orgUser.UserId.Value);
}
return;
}
orgUser.Status = OrganizationUserStatusType.Revoked;
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(id);
await dbContext.SaveChangesAsync();
}
}
@ -468,17 +509,17 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var orgUser = await GetDbSet(dbContext).FindAsync(id);
if (orgUser != null)
var orgUser = await dbContext.OrganizationUsers
.FirstOrDefaultAsync(ou => ou.Id == id && ou.Status == OrganizationUserStatusType.Revoked);
if (orgUser == null)
{
dbContext.Update(orgUser);
orgUser.Status = status;
await dbContext.SaveChangesAsync();
if (orgUser.UserId.HasValue)
{
await UserBumpAccountRevisionDate(orgUser.UserId.Value);
}
return;
}
orgUser.Status = status;
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(id);
await dbContext.SaveChangesAsync();
}
}
}

View File

@ -14,6 +14,17 @@ public class ProviderRepository : Repository<Provider, Models.Provider, Guid>, I
: base(serviceScopeFactory, mapper, context => context.Providers)
{ }
public override async Task DeleteAsync(Provider provider)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByProviderIdAsync(provider.Id);
await dbContext.SaveChangesAsync();
}
await base.DeleteAsync(provider);
}
public async Task<ICollection<Provider>> SearchAsync(string name, string userEmail, int skip, int take)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@ -16,6 +16,17 @@ public class ProviderUserRepository :
: base(serviceScopeFactory, mapper, (DatabaseContext context) => context.ProviderUsers)
{ }
public override async Task DeleteAsync(ProviderUser providerUser)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByProviderUserIdAsync(providerUser.Id);
await dbContext.SaveChangesAsync();
}
await base.DeleteAsync(providerUser);
}
public async Task<int> GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers)
{
using (var scope = ServiceScopeFactory.CreateScope())
@ -59,7 +70,10 @@ public class ProviderUserRepository :
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await UserBumpAccountRevisionDateByProviderUserIds(providerUserIds.ToArray());
foreach (var providerUserId in providerUserIds)
{
await dbContext.UserBumpAccountRevisionDateByProviderUserIdAsync(providerUserId);
}
var entities = dbContext.ProviderUsers.Where(pu => providerUserIds.Contains(pu.Id));
dbContext.ProviderUsers.RemoveRange(entities);
await dbContext.SaveChangesAsync();

View File

@ -1,73 +0,0 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using CollectionCipher = Bit.Infrastructure.EntityFramework.Models.CollectionCipher;
namespace Bit.Infrastructure.EntityFramework.Repositories.Queries;
public class CipherUpdateCollectionsQuery : IQuery<CollectionCipher>
{
private readonly Cipher _cipher;
private readonly IEnumerable<Guid> _collectionIds;
public CipherUpdateCollectionsQuery(Cipher cipher, IEnumerable<Guid> collectionIds)
{
_cipher = cipher;
_collectionIds = collectionIds;
}
public virtual IQueryable<CollectionCipher> Run(DatabaseContext dbContext)
{
if (!_cipher.OrganizationId.HasValue || !_collectionIds.Any())
{
return null;
}
var availibleCollections = !_cipher.UserId.HasValue ?
from c in dbContext.Collections
where c.OrganizationId == _cipher.OrganizationId
select c.Id :
from c in dbContext.Collections
join o in dbContext.Organizations
on c.OrganizationId equals o.Id
join ou in dbContext.OrganizationUsers
on new { OrganizationId = o.Id, _cipher.UserId } equals new { ou.OrganizationId, ou.UserId }
join cu in dbContext.CollectionUsers
on new { ou.AccessAll, CollectionId = c.Id, OrganizationUserId = ou.Id } equals
new { AccessAll = false, cu.CollectionId, cu.OrganizationUserId } into cu_g
from cu in cu_g.DefaultIfEmpty()
join gu in dbContext.GroupUsers
on new { CollectionId = (Guid?)cu.CollectionId, ou.AccessAll, OrganizationUserId = ou.Id } equals
new { CollectionId = (Guid?)null, AccessAll = false, gu.OrganizationUserId } into gu_g
from gu in gu_g.DefaultIfEmpty()
join g in dbContext.Groups
on gu.GroupId equals g.Id into g_g
from g in g_g.DefaultIfEmpty()
join cg in dbContext.CollectionGroups
on new { g.AccessAll, CollectionId = c.Id, gu.GroupId } equals
new { AccessAll = false, cg.CollectionId, cg.GroupId } into cg_g
from cg in cg_g.DefaultIfEmpty()
where o.Id == _cipher.OrganizationId &&
o.Enabled &&
ou.Status == OrganizationUserStatusType.Confirmed &&
(ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)
select c.Id;
if (!availibleCollections.Any())
{
return null;
}
var query = from c in availibleCollections
select new CollectionCipher { CollectionId = c, CipherId = _cipher.Id };
return query;
}
}

View File

@ -6,11 +6,19 @@ namespace Bit.Infrastructure.EntityFramework.Repositories.Queries;
public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery<User>
{
private readonly Cipher _cipher;
private readonly Guid _cipherId;
private readonly Guid? _organizationId;
public UserBumpAccountRevisionDateByCipherIdQuery(Cipher cipher)
{
_cipher = cipher;
_cipherId = cipher.Id;
_organizationId = cipher.OrganizationId;
}
public UserBumpAccountRevisionDateByCipherIdQuery(Guid cipherId, Guid? organizationId)
{
_cipherId = cipherId;
_organizationId = organizationId;
}
public IQueryable<User> Run(DatabaseContext dbContext)
@ -21,7 +29,7 @@ public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery<User>
on u.Id equals ou.UserId
join collectionCipher in dbContext.CollectionCiphers
on _cipher.Id equals collectionCipher.CipherId into cc_g
on _cipherId equals collectionCipher.CipherId into cc_g
from cc in cc_g.DefaultIfEmpty()
join collectionUser in dbContext.CollectionUsers
@ -43,7 +51,7 @@ public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery<User>
new { AccessAll = false, collectionGroup.GroupId, collectionGroup.CollectionId } into cg_g
from cg in cg_g.DefaultIfEmpty()
where ou.OrganizationId == _cipher.OrganizationId &&
where ou.OrganizationId == _organizationId &&
ou.Status == OrganizationUserStatusType.Confirmed &&
(cu.CollectionId != null ||
cg.CollectionId != null ||

View File

@ -15,11 +15,17 @@ public class SendRepository : Repository<Core.Entities.Send, Send, Guid>, ISendR
public override async Task<Core.Entities.Send> CreateAsync(Core.Entities.Send send)
{
send = await base.CreateAsync(send);
if (send.UserId.HasValue)
using (var scope = ServiceScopeFactory.CreateScope())
{
await UserUpdateStorage(send.UserId.Value);
await UserBumpAccountRevisionDate(send.UserId.Value);
var dbContext = GetDatabaseContext(scope);
if (send.UserId.HasValue)
{
await UserUpdateStorage(send.UserId.Value);
await dbContext.UserBumpAccountRevisionDateAsync(send.UserId.Value);
await dbContext.SaveChangesAsync();
}
}
return send;
}