1
0
mirror of https://github.com/bitwarden/server.git synced 2025-06-30 15:42:48 -05:00

Update ReplaceAsync Implementation in EF CollectionRepository (#4611)

* Add Collections Tests

* Update CollectionRepository Implementation

* Test Adding And Deleting Through Replace

* Format
This commit is contained in:
Justin Baur
2024-08-14 13:50:29 -04:00
committed by GitHub
parent db4ff79c91
commit 3d7fe4f8af
4 changed files with 206 additions and 126 deletions

View File

@ -47,7 +47,7 @@ public interface ICollectionRepository : IRepository<Collection, Guid>
/// </summary>
Task<CollectionAdminDetails?> GetByIdWithPermissionsAsync(Guid collectionId, Guid? userId, bool includeAccessRelationships);
Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users);
Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users);
Task ReplaceAsync(Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users);
Task DeleteUserAsync(Guid collectionId, Guid organizationUserId);
Task UpdateUsersAsync(Guid id, IEnumerable<CollectionAccessSelection> users);

View File

@ -50,7 +50,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
}
}
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users)
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users)
{
await CreateAsync(obj);
using (var scope = ServiceScopeFactory.CreateScope())
@ -523,6 +523,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
await ReplaceCollectionGroupsAsync(dbContext, collection, groups);
await ReplaceCollectionUsersAsync(dbContext, collection, users);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
}
}
@ -689,133 +690,75 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
}
}
private async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
private static async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
{
var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId);
var modifiedGroupEntities = dbContext.Groups.Where(x => groups.Select(x => x.Id).Contains(x.Id));
var target = (from cg in dbContext.CollectionGroups
join g in modifiedGroupEntities
on cg.CollectionId equals collection.Id into s_g
from g in s_g.DefaultIfEmpty()
where g == null || cg.GroupId == g.Id
select new { cg, g }).AsNoTracking();
var source = (from g in modifiedGroupEntities
from cg in dbContext.CollectionGroups
.Where(cg => cg.CollectionId == collection.Id && cg.GroupId == g.Id).DefaultIfEmpty()
select new { cg, g }).AsNoTracking();
var union = await target
.Union(source)
.Where(x =>
x.cg == null ||
((x.g == null || x.g.Id == x.cg.GroupId) &&
(x.cg.CollectionId == collection.Id)))
.AsNoTracking()
.ToListAsync();
var insert = union.Where(x => x.cg == null && groupsInOrg.Any(c => x.g.Id == c.Id))
.Select(x => new CollectionGroup
{
CollectionId = collection.Id,
GroupId = x.g.Id,
ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly,
HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords,
Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage
}).ToList();
var update = union
.Where(
x => x.g != null &&
x.cg != null &&
(x.cg.ReadOnly != groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly ||
x.cg.HidePasswords != groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords ||
x.cg.Manage != groups.FirstOrDefault(g => g.Id == x.g.Id).Manage)
)
.Select(x => new CollectionGroup
{
CollectionId = collection.Id,
GroupId = x.g.Id,
ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly,
HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords,
Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage,
});
var delete = union
.Where(
x => x.g == null &&
x.cg.CollectionId == collection.Id
)
.Select(x => new CollectionGroup
{
CollectionId = collection.Id,
GroupId = x.cg.GroupId,
})
.ToList();
var existingCollectionGroups = await dbContext.CollectionGroups
.Where(cg => cg.CollectionId == collection.Id)
.ToDictionaryAsync(cg => cg.GroupId);
await dbContext.AddRangeAsync(insert);
dbContext.UpdateRange(update);
dbContext.RemoveRange(delete);
await dbContext.SaveChangesAsync();
foreach (var group in groups)
{
if (existingCollectionGroups.TryGetValue(group.Id, out var existingCollectionGroup))
{
// It already exists, update it
existingCollectionGroup.HidePasswords = group.HidePasswords;
existingCollectionGroup.ReadOnly = group.ReadOnly;
existingCollectionGroup.Manage = group.Manage;
dbContext.CollectionGroups.Update(existingCollectionGroup);
}
else
{
// This is a brand new entry, add it
dbContext.CollectionGroups.Add(new CollectionGroup
{
GroupId = group.Id,
CollectionId = collection.Id,
HidePasswords = group.HidePasswords,
ReadOnly = group.ReadOnly,
Manage = group.Manage,
});
}
}
var requestedGroupIds = groups.Select(g => g.Id).ToArray();
var toDelete = existingCollectionGroups.Values.Where(cg => !requestedGroupIds.Contains(cg.GroupId));
dbContext.CollectionGroups.RemoveRange(toDelete);
// SaveChangesAsync is expected to be called outside this method
}
private async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> users)
private static async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> users)
{
var usersInOrg = dbContext.OrganizationUsers.Where(u => u.OrganizationId == collection.OrganizationId);
var modifiedUserEntities = dbContext.OrganizationUsers.Where(x => users.Select(x => x.Id).Contains(x.Id));
var target = (from cu in dbContext.CollectionUsers
join u in modifiedUserEntities
on cu.CollectionId equals collection.Id into s_g
from u in s_g.DefaultIfEmpty()
where u == null || cu.OrganizationUserId == u.Id
select new { cu, u }).AsNoTracking();
var source = (from u in modifiedUserEntities
from cu in dbContext.CollectionUsers
.Where(cu => cu.CollectionId == collection.Id && cu.OrganizationUserId == u.Id).DefaultIfEmpty()
select new { cu, u }).AsNoTracking();
var union = await target
.Union(source)
.Where(x =>
x.cu == null ||
((x.u == null || x.u.Id == x.cu.OrganizationUserId) &&
(x.cu.CollectionId == collection.Id)))
.AsNoTracking()
.ToListAsync();
var insert = union.Where(x => x.u == null && usersInOrg.Any(c => x.u.Id == c.Id))
.Select(x => new CollectionUser
{
CollectionId = collection.Id,
OrganizationUserId = x.u.Id,
ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly,
HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords,
Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage,
}).ToList();
var update = union
.Where(
x => x.u != null &&
x.cu != null &&
(x.cu.ReadOnly != users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly ||
x.cu.HidePasswords != users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords ||
x.cu.Manage != users.FirstOrDefault(u => u.Id == x.u.Id).Manage)
)
.Select(x => new CollectionUser
{
CollectionId = collection.Id,
OrganizationUserId = x.u.Id,
ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly,
HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords,
Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage,
});
var delete = union
.Where(
x => x.u == null &&
x.cu.CollectionId == collection.Id
)
.Select(x => new CollectionUser
{
CollectionId = collection.Id,
OrganizationUserId = x.cu.OrganizationUserId,
})
.ToList();
var existingCollectionUsers = await dbContext.CollectionUsers
.Where(cu => cu.CollectionId == collection.Id)
.ToDictionaryAsync(cu => cu.OrganizationUserId);
await dbContext.AddRangeAsync(insert);
dbContext.UpdateRange(update);
dbContext.RemoveRange(delete);
await dbContext.SaveChangesAsync();
foreach (var user in users)
{
if (existingCollectionUsers.TryGetValue(user.Id, out var existingCollectionUser))
{
// This is an existing entry, update it.
existingCollectionUser.HidePasswords = user.HidePasswords;
existingCollectionUser.ReadOnly = user.ReadOnly;
existingCollectionUser.Manage = user.Manage;
dbContext.CollectionUsers.Update(existingCollectionUser);
}
else
{
// This is a brand new entry, add it
dbContext.CollectionUsers.Add(new CollectionUser
{
OrganizationUserId = user.Id,
CollectionId = collection.Id,
HidePasswords = user.HidePasswords,
ReadOnly = user.ReadOnly,
Manage = user.Manage,
});
}
}
var requestedUserIds = users.Select(u => u.Id).ToArray();
var toDelete = existingCollectionUsers.Values.Where(cu => !requestedUserIds.Contains(cu.OrganizationUserId));
dbContext.CollectionUsers.RemoveRange(toDelete);
// SaveChangesAsync is expected to be called outside this method
}
}