1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-01 08:02:49 -05:00

Refactor policy checks (#1536)

* Move policy checking logic inside PolicyService

* Refactor to use currentContext.ManagePolicies

* Make orgUser status check more semantic

* Fix single org user checks

* Use CoreHelper implementation to deserialize json

* Refactor policy checks to use db query

* Use new db query for enforcing 2FA Policy

* Add Policy_ReadByTypeApplicableToUser

* Stub out EF implementations

* Refactor: use PolicyRepository only

* Refactor tests

* Copy SQL queries to proj and update sqlproj file

* Refactor importCiphersAsync to use new method

* Add EF implementations and tests

* Refactor SQL to remove unnecessary operations
This commit is contained in:
Thomas Rittson
2021-09-28 06:54:28 +10:00
committed by GitHub
parent fbf3e0dcdc
commit 66629b2f1c
18 changed files with 505 additions and 197 deletions

View File

@ -54,5 +54,30 @@ namespace Bit.Core.Repositories.EntityFramework
return Mapper.Map<List<TableModel.Policy>>(results);
}
}
public async Task<ICollection<Policy>> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType,
OrganizationUserStatusType minStatus)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus);
var results = await query.Run(dbContext).ToListAsync();
return Mapper.Map<List<TableModel.Policy>>(results);
}
}
public async Task<int> GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType,
OrganizationUserStatusType minStatus)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus);
return await GetCountFromQuery(query);
}
}
}
}

View File

@ -0,0 +1,45 @@
using System.Collections.Generic;
using System.Linq;
using Bit.Core.Enums;
using Bit.Core.Models.EntityFramework;
using System;
namespace Bit.Core.Repositories.EntityFramework.Queries
{
public class PolicyReadByTypeApplicableToUserQuery : IQuery<Policy>
{
private readonly Guid _userId;
private readonly PolicyType _policyType;
private readonly OrganizationUserStatusType _minimumStatus;
public PolicyReadByTypeApplicableToUserQuery(Guid userId, PolicyType policyType, OrganizationUserStatusType minimumStatus)
{
_userId = userId;
_policyType = policyType;
_minimumStatus = minimumStatus;
}
public IQueryable<Policy> Run(DatabaseContext dbContext)
{
var providerOrganizations = from pu in dbContext.ProviderUsers
where pu.UserId == _userId
join po in dbContext.ProviderOrganizations
on pu.ProviderId equals po.ProviderId
select po;
var query = from p in dbContext.Policies
join ou in dbContext.OrganizationUsers
on p.OrganizationId equals ou.OrganizationId
where ou.UserId == _userId &&
p.Type == _policyType &&
p.Enabled &&
ou.Status >= _minimumStatus &&
ou.Type >= OrganizationUserType.User &&
(ou.Permissions == null ||
ou.Permissions.Contains($"\"managePolicies\":false")) &&
!providerOrganizations.Any(po => po.OrganizationId == p.OrganizationId)
select p;
return query;
}
}
}

View File

@ -11,5 +11,9 @@ namespace Bit.Core.Repositories
Task<Policy> GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type);
Task<ICollection<Policy>> GetManyByOrganizationIdAsync(Guid organizationId);
Task<ICollection<Policy>> GetManyByUserIdAsync(Guid userId);
Task<ICollection<Policy>> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType,
OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted);
Task<int> GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType,
OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted);
}
}

View File

@ -58,5 +58,33 @@ namespace Bit.Core.Repositories.SqlServer
return results.ToList();
}
}
public async Task<ICollection<Policy>> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType,
OrganizationUserStatusType minStatus)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<Policy>(
$"[{Schema}].[{Table}_ReadByTypeApplicableToUser]",
new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus },
commandType: CommandType.StoredProcedure);
return results.ToList();
}
}
public async Task<int> GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType,
OrganizationUserStatusType minStatus)
{
using (var connection = new SqlConnection(ConnectionString))
{
var result = await connection.ExecuteScalarAsync<int>(
$"[{Schema}].[{Table}_CountByTypeApplicableToUser]",
new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus },
commandType: CommandType.StoredProcedure);
return result;
}
}
}
}

View File

@ -26,7 +26,6 @@ namespace Bit.Core.Services
private readonly ICollectionRepository _collectionRepository;
private readonly IUserRepository _userRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly ICollectionCipherRepository _collectionCipherRepository;
private readonly IPushNotificationService _pushService;
private readonly IAttachmentStorageService _attachmentStorageService;
@ -43,7 +42,6 @@ namespace Bit.Core.Services
ICollectionRepository collectionRepository,
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionCipherRepository collectionCipherRepository,
IPushNotificationService pushService,
IAttachmentStorageService attachmentStorageService,
@ -58,7 +56,6 @@ namespace Bit.Core.Services
_collectionRepository = collectionRepository;
_userRepository = userRepository;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_collectionCipherRepository = collectionCipherRepository;
_pushService = pushService;
_attachmentStorageService = attachmentStorageService;
@ -139,19 +136,11 @@ namespace Bit.Core.Services
else
{
// Make sure the user can save new ciphers to their personal vault
var userPolicies = await _policyRepository.GetManyByUserIdAsync(savingUserId);
if (userPolicies != null)
var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(savingUserId,
PolicyType.PersonalOwnership);
if (personalOwnershipPolicyCount > 0)
{
foreach (var policy in userPolicies.Where(p => p.Enabled && p.Type == PolicyType.PersonalOwnership))
{
var org = await _organizationUserRepository.GetDetailsByUserAsync(savingUserId, policy.OrganizationId,
OrganizationUserStatusType.Confirmed);
if (org != null && org.Enabled && org.UsePolicies
&& org.Type != OrganizationUserType.Admin && org.Type != OrganizationUserType.Owner)
{
throw new BadRequestException("Due to an Enterprise Policy, you are restricted from saving items to your personal vault.");
}
}
throw new BadRequestException("Due to an Enterprise Policy, you are restricted from saving items to your personal vault.");
}
await _cipherRepository.CreateAsync(cipher);
}
@ -688,26 +677,13 @@ namespace Bit.Core.Services
{
var userId = folders.FirstOrDefault()?.UserId ?? ciphers.FirstOrDefault()?.UserId;
// Check user is allowed to import to personal vault
if (userId.HasValue)
// Make sure the user can save new ciphers to their personal vault
var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value,
PolicyType.PersonalOwnership);
if (personalOwnershipPolicyCount > 0)
{
var policies = await _policyRepository.GetManyByUserIdAsync(userId.Value);
var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value);
var orgsWithBlockingPolicy = policies
.Where(p => p.Enabled && p.Type == PolicyType.PersonalOwnership)
.Select(p => p.OrganizationId);
var blockedByPolicy = allOrgUsers.Any(ou =>
ou.Type != OrganizationUserType.Owner &&
ou.Type != OrganizationUserType.Admin &&
ou.Status != OrganizationUserStatusType.Invited &&
orgsWithBlockingPolicy.Contains(ou.OrganizationId));
if (blockedByPolicy)
{
throw new BadRequestException("You cannot import items into your personal vault because you are " +
"a member of an organization which forbids it.");
}
throw new BadRequestException("You cannot import items into your personal vault because you are " +
"a member of an organization which forbids it.");
}
foreach (var cipher in ciphers)

View File

@ -639,16 +639,8 @@ namespace Bit.Core.Services
private async Task ValidateSignUpPoliciesAsync(Guid ownerId)
{
var policies = await _policyRepository.GetManyByUserIdAsync(ownerId);
var orgUsers = await _organizationUserRepository.GetManyByUserAsync(ownerId);
var orgsWithSingleOrgPolicy = policies.Where(p => p.Enabled && p.Type == PolicyType.SingleOrg)
.Select(p => p.OrganizationId);
var blockedBySingleOrgPolicy = orgUsers.Any(ou => ou is { Type: OrganizationUserType.Owner } &&
ou.Type != OrganizationUserType.Admin &&
ou.Status != OrganizationUserStatusType.Invited &&
orgsWithSingleOrgPolicy.Contains(ou.OrganizationId));
if (blockedBySingleOrgPolicy)
var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(ownerId, PolicyType.SingleOrg);
if (singleOrgPolicyCount > 0)
{
throw new BadRequestException("You may not create an organization. You belong to an organization " +
"which has a policy that prohibits you from being a member of any other organization.");
@ -1324,48 +1316,37 @@ namespace Bit.Core.Services
}
}
bool notExempt(OrganizationUser organizationUser)
{
return organizationUser.Type != OrganizationUserType.Owner &&
organizationUser.Type != OrganizationUserType.Admin;
}
var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(user.Id);
// Enforce Single Organization Policy of organization user is trying to join
var thisSingleOrgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgUser.OrganizationId, PolicyType.SingleOrg);
if (thisSingleOrgPolicy != null &&
thisSingleOrgPolicy.Enabled &&
notExempt(orgUser) &&
allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId))
var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(user.Id);
var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId);
var invitedSingleOrgPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id,
PolicyType.SingleOrg, OrganizationUserStatusType.Invited);
if (hasOtherOrgs && invitedSingleOrgPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId))
{
throw new BadRequestException("You may not join this organization until you leave or remove " +
"all other organizations.");
}
// Enforce Single Organization Policy of other organizations user is a member of
var policies = await _policyRepository.GetManyByUserIdAsync(user.Id);
var orgsWithSingleOrgPolicy = policies.Where(p => p.Enabled && p.Type == PolicyType.SingleOrg)
.Select(p => p.OrganizationId);
var blockedBySingleOrgPolicy = allOrgUsers.Any(ou => notExempt(ou) &&
ou.Status != OrganizationUserStatusType.Invited &&
orgsWithSingleOrgPolicy.Contains(ou.OrganizationId));
if (blockedBySingleOrgPolicy)
var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(user.Id,
PolicyType.SingleOrg);
if (singleOrgPolicyCount > 0)
{
throw new BadRequestException("You cannot join this organization because you are a member of " +
"an organization which forbids it");
"another organization which forbids it");
}
var twoFactorPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication);
if (!await userService.TwoFactorIsEnabledAsync(user) &&
twoFactorPolicy != null &&
twoFactorPolicy.Enabled &&
notExempt(orgUser))
// Enforce Two Factor Authentication Policy of organization user is trying to join
if (!await userService.TwoFactorIsEnabledAsync(user))
{
throw new BadRequestException("You cannot join this organization until you enable " +
"two-step login on your user account.");
var invitedTwoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id,
PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited);
if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId))
{
throw new BadRequestException("You cannot join this organization until you enable " +
"two-step login on your user account.");
}
}
orgUser.Status = OrganizationUserStatusType.Accepted;

View File

@ -10,6 +10,7 @@ using Bit.Core.Models.Data;
using Bit.Core.Models.Table;
using Bit.Core.Repositories;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Identity;
using Newtonsoft.Json;
@ -280,40 +281,24 @@ namespace Bit.Core.Services
return;
}
var policies = await _policyRepository.GetManyByUserIdAsync(userId.Value);
if (policies == null)
var disableSendPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value,
PolicyType.DisableSend);
if (disableSendPolicyCount > 0)
{
return;
}
foreach (var policy in policies.Where(p => p.Enabled && p.Type == PolicyType.DisableSend))
{
if (!await _currentContext.ManagePolicies(policy.OrganizationId))
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
if (send.HideEmail.GetValueOrDefault())
{
foreach (var policy in policies.Where(p => p.Enabled && p.Type == PolicyType.SendOptions))
var sendOptionsPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId.Value, PolicyType.SendOptions);
foreach (var policy in sendOptionsPolicies)
{
if (await _currentContext.ManagePolicies(policy.OrganizationId))
{
continue;
}
SendOptionsPolicyData data = null;
if (policy.Data != null)
{
data = JsonConvert.DeserializeObject<SendOptionsPolicyData>(policy.Data);
}
var data = CoreHelpers.LoadClassFromJsonData<SendOptionsPolicyData>(policy.Data);
if (data?.DisableHideEmail ?? false)
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
}
}

View File

@ -1303,24 +1303,18 @@ namespace Bit.Core.Services
private async Task CheckPoliciesOnTwoFactorRemovalAsync(User user, IOrganizationService organizationService)
{
var policies = await _policyRepository.GetManyByUserIdAsync(user.Id);
var twoFactorPolicies = policies.Where(p => p.Type == PolicyType.TwoFactorAuthentication && p.Enabled);
if (twoFactorPolicies.Any())
var twoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id,
PolicyType.TwoFactorAuthentication);
var removeOrgUserTasks = twoFactorPolicies.Select(async p =>
{
var userOrgs = await _organizationUserRepository.GetManyByUserAsync(user.Id);
var ownerOrgs = userOrgs.Where(o => o.Type == OrganizationUserType.Owner)
.Select(o => o.OrganizationId).ToHashSet();
foreach (var policy in twoFactorPolicies)
{
if (!ownerOrgs.Contains(policy.OrganizationId))
{
await organizationService.DeleteUserAsync(policy.OrganizationId, user.Id);
var organization = await _organizationRepository.GetByIdAsync(policy.OrganizationId);
await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(
organization.Name, user.Email);
}
}
}
await organizationService.DeleteUserAsync(p.OrganizationId, user.Id);
var organization = await _organizationRepository.GetByIdAsync(p.OrganizationId);
await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(
organization.Name, user.Email);
}).ToArray();
await Task.WhenAll(removeOrgUserTasks);
}
public override async Task<IdentityResult> ConfirmEmailAsync(User user, string token)