mirror of
https://github.com/bitwarden/server.git
synced 2025-05-20 19:14:32 -05:00
Update ReplaceAsync
Implementation in EF CollectionRepository
(#4611)
* Add Collections Tests * Update CollectionRepository Implementation * Test Adding And Deleting Through Replace * Format
This commit is contained in:
parent
db4ff79c91
commit
3d7fe4f8af
@ -47,7 +47,7 @@ public interface ICollectionRepository : IRepository<Collection, Guid>
|
|||||||
/// </summary>
|
/// </summary>
|
||||||
Task<CollectionAdminDetails?> GetByIdWithPermissionsAsync(Guid collectionId, Guid? userId, bool includeAccessRelationships);
|
Task<CollectionAdminDetails?> GetByIdWithPermissionsAsync(Guid collectionId, Guid? userId, bool includeAccessRelationships);
|
||||||
|
|
||||||
Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users);
|
Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users);
|
||||||
Task ReplaceAsync(Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users);
|
Task ReplaceAsync(Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users);
|
||||||
Task DeleteUserAsync(Guid collectionId, Guid organizationUserId);
|
Task DeleteUserAsync(Guid collectionId, Guid organizationUserId);
|
||||||
Task UpdateUsersAsync(Guid id, IEnumerable<CollectionAccessSelection> users);
|
Task UpdateUsersAsync(Guid id, IEnumerable<CollectionAccessSelection> users);
|
||||||
|
@ -50,7 +50,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users)
|
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users)
|
||||||
{
|
{
|
||||||
await CreateAsync(obj);
|
await CreateAsync(obj);
|
||||||
using (var scope = ServiceScopeFactory.CreateScope())
|
using (var scope = ServiceScopeFactory.CreateScope())
|
||||||
@ -523,6 +523,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
|||||||
await ReplaceCollectionGroupsAsync(dbContext, collection, groups);
|
await ReplaceCollectionGroupsAsync(dbContext, collection, groups);
|
||||||
await ReplaceCollectionUsersAsync(dbContext, collection, users);
|
await ReplaceCollectionUsersAsync(dbContext, collection, users);
|
||||||
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
|
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
|
||||||
|
await dbContext.SaveChangesAsync();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -689,133 +690,75 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
|
private static async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
|
||||||
{
|
{
|
||||||
var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId);
|
var existingCollectionGroups = await dbContext.CollectionGroups
|
||||||
var modifiedGroupEntities = dbContext.Groups.Where(x => groups.Select(x => x.Id).Contains(x.Id));
|
.Where(cg => cg.CollectionId == collection.Id)
|
||||||
var target = (from cg in dbContext.CollectionGroups
|
.ToDictionaryAsync(cg => cg.GroupId);
|
||||||
join g in modifiedGroupEntities
|
|
||||||
on cg.CollectionId equals collection.Id into s_g
|
|
||||||
from g in s_g.DefaultIfEmpty()
|
|
||||||
where g == null || cg.GroupId == g.Id
|
|
||||||
select new { cg, g }).AsNoTracking();
|
|
||||||
var source = (from g in modifiedGroupEntities
|
|
||||||
from cg in dbContext.CollectionGroups
|
|
||||||
.Where(cg => cg.CollectionId == collection.Id && cg.GroupId == g.Id).DefaultIfEmpty()
|
|
||||||
select new { cg, g }).AsNoTracking();
|
|
||||||
var union = await target
|
|
||||||
.Union(source)
|
|
||||||
.Where(x =>
|
|
||||||
x.cg == null ||
|
|
||||||
((x.g == null || x.g.Id == x.cg.GroupId) &&
|
|
||||||
(x.cg.CollectionId == collection.Id)))
|
|
||||||
.AsNoTracking()
|
|
||||||
.ToListAsync();
|
|
||||||
var insert = union.Where(x => x.cg == null && groupsInOrg.Any(c => x.g.Id == c.Id))
|
|
||||||
.Select(x => new CollectionGroup
|
|
||||||
{
|
|
||||||
CollectionId = collection.Id,
|
|
||||||
GroupId = x.g.Id,
|
|
||||||
ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly,
|
|
||||||
HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords,
|
|
||||||
Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage
|
|
||||||
}).ToList();
|
|
||||||
var update = union
|
|
||||||
.Where(
|
|
||||||
x => x.g != null &&
|
|
||||||
x.cg != null &&
|
|
||||||
(x.cg.ReadOnly != groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly ||
|
|
||||||
x.cg.HidePasswords != groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords ||
|
|
||||||
x.cg.Manage != groups.FirstOrDefault(g => g.Id == x.g.Id).Manage)
|
|
||||||
)
|
|
||||||
.Select(x => new CollectionGroup
|
|
||||||
{
|
|
||||||
CollectionId = collection.Id,
|
|
||||||
GroupId = x.g.Id,
|
|
||||||
ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly,
|
|
||||||
HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords,
|
|
||||||
Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage,
|
|
||||||
});
|
|
||||||
var delete = union
|
|
||||||
.Where(
|
|
||||||
x => x.g == null &&
|
|
||||||
x.cg.CollectionId == collection.Id
|
|
||||||
)
|
|
||||||
.Select(x => new CollectionGroup
|
|
||||||
{
|
|
||||||
CollectionId = collection.Id,
|
|
||||||
GroupId = x.cg.GroupId,
|
|
||||||
})
|
|
||||||
.ToList();
|
|
||||||
|
|
||||||
await dbContext.AddRangeAsync(insert);
|
foreach (var group in groups)
|
||||||
dbContext.UpdateRange(update);
|
{
|
||||||
dbContext.RemoveRange(delete);
|
if (existingCollectionGroups.TryGetValue(group.Id, out var existingCollectionGroup))
|
||||||
await dbContext.SaveChangesAsync();
|
{
|
||||||
|
// It already exists, update it
|
||||||
|
existingCollectionGroup.HidePasswords = group.HidePasswords;
|
||||||
|
existingCollectionGroup.ReadOnly = group.ReadOnly;
|
||||||
|
existingCollectionGroup.Manage = group.Manage;
|
||||||
|
dbContext.CollectionGroups.Update(existingCollectionGroup);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// This is a brand new entry, add it
|
||||||
|
dbContext.CollectionGroups.Add(new CollectionGroup
|
||||||
|
{
|
||||||
|
GroupId = group.Id,
|
||||||
|
CollectionId = collection.Id,
|
||||||
|
HidePasswords = group.HidePasswords,
|
||||||
|
ReadOnly = group.ReadOnly,
|
||||||
|
Manage = group.Manage,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> users)
|
var requestedGroupIds = groups.Select(g => g.Id).ToArray();
|
||||||
{
|
var toDelete = existingCollectionGroups.Values.Where(cg => !requestedGroupIds.Contains(cg.GroupId));
|
||||||
var usersInOrg = dbContext.OrganizationUsers.Where(u => u.OrganizationId == collection.OrganizationId);
|
dbContext.CollectionGroups.RemoveRange(toDelete);
|
||||||
var modifiedUserEntities = dbContext.OrganizationUsers.Where(x => users.Select(x => x.Id).Contains(x.Id));
|
// SaveChangesAsync is expected to be called outside this method
|
||||||
var target = (from cu in dbContext.CollectionUsers
|
}
|
||||||
join u in modifiedUserEntities
|
|
||||||
on cu.CollectionId equals collection.Id into s_g
|
|
||||||
from u in s_g.DefaultIfEmpty()
|
|
||||||
where u == null || cu.OrganizationUserId == u.Id
|
|
||||||
select new { cu, u }).AsNoTracking();
|
|
||||||
var source = (from u in modifiedUserEntities
|
|
||||||
from cu in dbContext.CollectionUsers
|
|
||||||
.Where(cu => cu.CollectionId == collection.Id && cu.OrganizationUserId == u.Id).DefaultIfEmpty()
|
|
||||||
select new { cu, u }).AsNoTracking();
|
|
||||||
var union = await target
|
|
||||||
.Union(source)
|
|
||||||
.Where(x =>
|
|
||||||
x.cu == null ||
|
|
||||||
((x.u == null || x.u.Id == x.cu.OrganizationUserId) &&
|
|
||||||
(x.cu.CollectionId == collection.Id)))
|
|
||||||
.AsNoTracking()
|
|
||||||
.ToListAsync();
|
|
||||||
var insert = union.Where(x => x.u == null && usersInOrg.Any(c => x.u.Id == c.Id))
|
|
||||||
.Select(x => new CollectionUser
|
|
||||||
{
|
|
||||||
CollectionId = collection.Id,
|
|
||||||
OrganizationUserId = x.u.Id,
|
|
||||||
ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly,
|
|
||||||
HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords,
|
|
||||||
Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage,
|
|
||||||
}).ToList();
|
|
||||||
var update = union
|
|
||||||
.Where(
|
|
||||||
x => x.u != null &&
|
|
||||||
x.cu != null &&
|
|
||||||
(x.cu.ReadOnly != users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly ||
|
|
||||||
x.cu.HidePasswords != users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords ||
|
|
||||||
x.cu.Manage != users.FirstOrDefault(u => u.Id == x.u.Id).Manage)
|
|
||||||
)
|
|
||||||
.Select(x => new CollectionUser
|
|
||||||
{
|
|
||||||
CollectionId = collection.Id,
|
|
||||||
OrganizationUserId = x.u.Id,
|
|
||||||
ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly,
|
|
||||||
HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords,
|
|
||||||
Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage,
|
|
||||||
});
|
|
||||||
var delete = union
|
|
||||||
.Where(
|
|
||||||
x => x.u == null &&
|
|
||||||
x.cu.CollectionId == collection.Id
|
|
||||||
)
|
|
||||||
.Select(x => new CollectionUser
|
|
||||||
{
|
|
||||||
CollectionId = collection.Id,
|
|
||||||
OrganizationUserId = x.cu.OrganizationUserId,
|
|
||||||
})
|
|
||||||
.ToList();
|
|
||||||
|
|
||||||
await dbContext.AddRangeAsync(insert);
|
private static async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> users)
|
||||||
dbContext.UpdateRange(update);
|
{
|
||||||
dbContext.RemoveRange(delete);
|
var existingCollectionUsers = await dbContext.CollectionUsers
|
||||||
await dbContext.SaveChangesAsync();
|
.Where(cu => cu.CollectionId == collection.Id)
|
||||||
|
.ToDictionaryAsync(cu => cu.OrganizationUserId);
|
||||||
|
|
||||||
|
foreach (var user in users)
|
||||||
|
{
|
||||||
|
if (existingCollectionUsers.TryGetValue(user.Id, out var existingCollectionUser))
|
||||||
|
{
|
||||||
|
// This is an existing entry, update it.
|
||||||
|
existingCollectionUser.HidePasswords = user.HidePasswords;
|
||||||
|
existingCollectionUser.ReadOnly = user.ReadOnly;
|
||||||
|
existingCollectionUser.Manage = user.Manage;
|
||||||
|
dbContext.CollectionUsers.Update(existingCollectionUser);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// This is a brand new entry, add it
|
||||||
|
dbContext.CollectionUsers.Add(new CollectionUser
|
||||||
|
{
|
||||||
|
OrganizationUserId = user.Id,
|
||||||
|
CollectionId = collection.Id,
|
||||||
|
HidePasswords = user.HidePasswords,
|
||||||
|
ReadOnly = user.ReadOnly,
|
||||||
|
Manage = user.Manage,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestedUserIds = users.Select(u => u.Id).ToArray();
|
||||||
|
var toDelete = existingCollectionUsers.Values.Where(cu => !requestedUserIds.Contains(cu.OrganizationUserId));
|
||||||
|
dbContext.CollectionUsers.RemoveRange(toDelete);
|
||||||
|
// SaveChangesAsync is expected to be called outside this method
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,8 +13,8 @@
|
|||||||
<PackageReference Include="Microsoft.Extensions.Logging" Version="8.0.0" />
|
<PackageReference Include="Microsoft.Extensions.Logging" Version="8.0.0" />
|
||||||
<PackageReference Include="Microsoft.Extensions.TimeProvider.Testing" Version="8.6.0" />
|
<PackageReference Include="Microsoft.Extensions.TimeProvider.Testing" Version="8.6.0" />
|
||||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
|
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
|
||||||
<PackageReference Include="xunit" Version="2.4.1" />
|
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.3">
|
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitRunnerVisualStudioVersion)">
|
||||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||||
<PrivateAssets>all</PrivateAssets>
|
<PrivateAssets>all</PrivateAssets>
|
||||||
</PackageReference>
|
</PackageReference>
|
||||||
|
@ -463,4 +463,141 @@ public class CollectionRepositoryTests
|
|||||||
Assert.False(c3.Unmanaged);
|
Assert.False(c3.Unmanaged);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[DatabaseTheory, DatabaseData]
|
||||||
|
public async Task ReplaceAsync_Works(
|
||||||
|
IUserRepository userRepository,
|
||||||
|
IOrganizationRepository organizationRepository,
|
||||||
|
IOrganizationUserRepository organizationUserRepository,
|
||||||
|
IGroupRepository groupRepository,
|
||||||
|
ICollectionRepository collectionRepository)
|
||||||
|
{
|
||||||
|
var user = await userRepository.CreateAsync(new User
|
||||||
|
{
|
||||||
|
Name = "Test User",
|
||||||
|
Email = $"test+{Guid.NewGuid()}@email.com",
|
||||||
|
ApiKey = "TEST",
|
||||||
|
SecurityStamp = "stamp",
|
||||||
|
});
|
||||||
|
|
||||||
|
var organization = await organizationRepository.CreateAsync(new Organization
|
||||||
|
{
|
||||||
|
Name = "Test Org",
|
||||||
|
PlanType = PlanType.EnterpriseAnnually,
|
||||||
|
Plan = "Test Plan",
|
||||||
|
BillingEmail = "billing@email.com"
|
||||||
|
});
|
||||||
|
|
||||||
|
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
|
||||||
|
{
|
||||||
|
OrganizationId = organization.Id,
|
||||||
|
UserId = user.Id,
|
||||||
|
Status = OrganizationUserStatusType.Confirmed,
|
||||||
|
});
|
||||||
|
|
||||||
|
var orgUser2 = await organizationUserRepository.CreateAsync(new OrganizationUser
|
||||||
|
{
|
||||||
|
OrganizationId = organization.Id,
|
||||||
|
UserId = user.Id,
|
||||||
|
Status = OrganizationUserStatusType.Confirmed,
|
||||||
|
});
|
||||||
|
|
||||||
|
var orgUser3 = await organizationUserRepository.CreateAsync(new OrganizationUser
|
||||||
|
{
|
||||||
|
OrganizationId = organization.Id,
|
||||||
|
UserId = user.Id,
|
||||||
|
Status = OrganizationUserStatusType.Confirmed,
|
||||||
|
});
|
||||||
|
|
||||||
|
var group1 = await groupRepository.CreateAsync(new Group
|
||||||
|
{
|
||||||
|
Name = "Test Group #1",
|
||||||
|
OrganizationId = organization.Id,
|
||||||
|
});
|
||||||
|
|
||||||
|
var group2 = await groupRepository.CreateAsync(new Group
|
||||||
|
{
|
||||||
|
Name = "Test Group #2",
|
||||||
|
OrganizationId = organization.Id,
|
||||||
|
});
|
||||||
|
|
||||||
|
var group3 = await groupRepository.CreateAsync(new Group
|
||||||
|
{
|
||||||
|
Name = "Test Group #3",
|
||||||
|
OrganizationId = organization.Id,
|
||||||
|
});
|
||||||
|
|
||||||
|
var collection = new Collection
|
||||||
|
{
|
||||||
|
Name = "Test Collection Name",
|
||||||
|
OrganizationId = organization.Id,
|
||||||
|
};
|
||||||
|
|
||||||
|
await collectionRepository.CreateAsync(collection,
|
||||||
|
[
|
||||||
|
new CollectionAccessSelection { Id = group1.Id, Manage = true, HidePasswords = true, ReadOnly = false, },
|
||||||
|
new CollectionAccessSelection { Id = group2.Id, Manage = false, HidePasswords = false, ReadOnly = true, },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
new CollectionAccessSelection { Id = orgUser1.Id, Manage = true, HidePasswords = false, ReadOnly = true },
|
||||||
|
new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = true, ReadOnly = false },
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
collection.Name = "Updated Collection Name";
|
||||||
|
|
||||||
|
await collectionRepository.ReplaceAsync(collection,
|
||||||
|
[
|
||||||
|
// Should delete group1
|
||||||
|
new CollectionAccessSelection { Id = group2.Id, Manage = true, HidePasswords = true, ReadOnly = false, },
|
||||||
|
// Should add group3
|
||||||
|
new CollectionAccessSelection { Id = group3.Id, Manage = false, HidePasswords = false, ReadOnly = true, },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
// Should delete orgUser1
|
||||||
|
new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = false, ReadOnly = true },
|
||||||
|
// Should add orgUser3
|
||||||
|
new CollectionAccessSelection { Id = orgUser3.Id, Manage = true, HidePasswords = false, ReadOnly = true },
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Assert it
|
||||||
|
var info = await collectionRepository.GetByIdWithPermissionsAsync(collection.Id, user.Id, true);
|
||||||
|
|
||||||
|
Assert.NotNull(info);
|
||||||
|
|
||||||
|
Assert.Equal("Updated Collection Name", info.Name);
|
||||||
|
|
||||||
|
var groups = info.Groups.ToArray();
|
||||||
|
|
||||||
|
Assert.Equal(2, groups.Length);
|
||||||
|
|
||||||
|
var actualGroup2 = Assert.Single(groups.Where(g => g.Id == group2.Id));
|
||||||
|
|
||||||
|
Assert.True(actualGroup2.Manage);
|
||||||
|
Assert.True(actualGroup2.HidePasswords);
|
||||||
|
Assert.False(actualGroup2.ReadOnly);
|
||||||
|
|
||||||
|
var actualGroup3 = Assert.Single(groups.Where(g => g.Id == group3.Id));
|
||||||
|
|
||||||
|
Assert.False(actualGroup3.Manage);
|
||||||
|
Assert.False(actualGroup3.HidePasswords);
|
||||||
|
Assert.True(actualGroup3.ReadOnly);
|
||||||
|
|
||||||
|
var users = info.Users.ToArray();
|
||||||
|
|
||||||
|
Assert.Equal(2, users.Length);
|
||||||
|
|
||||||
|
var actualOrgUser2 = Assert.Single(users.Where(u => u.Id == orgUser2.Id));
|
||||||
|
|
||||||
|
Assert.False(actualOrgUser2.Manage);
|
||||||
|
Assert.False(actualOrgUser2.HidePasswords);
|
||||||
|
Assert.True(actualOrgUser2.ReadOnly);
|
||||||
|
|
||||||
|
var actualOrgUser3 = Assert.Single(users.Where(u => u.Id == orgUser3.Id));
|
||||||
|
|
||||||
|
Assert.True(actualOrgUser3.Manage);
|
||||||
|
Assert.False(actualOrgUser3.HidePasswords);
|
||||||
|
Assert.True(actualOrgUser3.ReadOnly);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user