1
0
mirror of https://github.com/bitwarden/server.git synced 2025-06-07 11:40:31 -05:00

feat(change-password-component): Change Password Update [18720] - Now sending back accepted mp policies on base validator

This commit is contained in:
Patrick Pimentel 2025-06-05 20:39:14 -04:00
parent 8165651285
commit 7ed190006b
No known key found for this signature in database
GPG Key ID: 4B27FC74C6422186
12 changed files with 144 additions and 32 deletions

View File

@ -81,12 +81,15 @@ public class SyncController : Controller
throw new BadRequestException("User not found."); throw new BadRequestException("User not found.");
} }
var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(
user.Id,
OrganizationUserStatusType.Confirmed); OrganizationUserStatusType.Confirmed);
var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(user.Id, var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(
user.Id,
ProviderUserStatusType.Confirmed); ProviderUserStatusType.Confirmed);
var providerUserOrganizationDetails = var providerUserOrganizationDetails =
await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(
user.Id,
ProviderUserStatusType.Confirmed); ProviderUserStatusType.Confirmed);
var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled); var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled);

View File

@ -20,6 +20,7 @@ public interface IPolicyRepository : IRepository<Policy, Guid>
Task<Policy?> GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type); Task<Policy?> GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type);
Task<ICollection<Policy>> GetManyByOrganizationIdAsync(Guid organizationId); Task<ICollection<Policy>> GetManyByOrganizationIdAsync(Guid organizationId);
Task<ICollection<Policy>> GetManyByUserIdAsync(Guid userId); Task<ICollection<Policy>> GetManyByUserIdAsync(Guid userId);
Task<ICollection<Policy>> GetManyAcceptedOrConfirmedByUserIdAsync(Guid userId);
/// <summary> /// <summary>
/// Gets all PolicyDetails for a user for all policy types. /// Gets all PolicyDetails for a user for all policy types.
/// </summary> /// </summary>

View File

@ -11,7 +11,7 @@ public interface IPolicyService
/// <summary> /// <summary>
/// Get the combined master password policy options for the specified user. /// Get the combined master password policy options for the specified user.
/// </summary> /// </summary>
Task<MasterPasswordPolicyData> GetMasterPasswordPolicyForUserAsync(User user); Task<MasterPasswordPolicyData> GetMasterPasswordPolicyForUserAsync(User user, bool getConfirmedOrAccepted = false);
Task<ICollection<OrganizationUserPolicyDetails>> GetPoliciesApplicableToUserAsync(Guid userId, PolicyType policyType, OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); Task<ICollection<OrganizationUserPolicyDetails>> GetPoliciesApplicableToUserAsync(Guid userId, PolicyType policyType, OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted);
Task<bool> AnyPoliciesApplicableToUserAsync(Guid userId, PolicyType policyType, OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); Task<bool> AnyPoliciesApplicableToUserAsync(Guid userId, PolicyType policyType, OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted);
} }

View File

@ -29,9 +29,13 @@ public class PolicyService : IPolicyService
_globalSettings = globalSettings; _globalSettings = globalSettings;
} }
public async Task<MasterPasswordPolicyData> GetMasterPasswordPolicyForUserAsync(User user) public async Task<MasterPasswordPolicyData> GetMasterPasswordPolicyForUserAsync(User user, bool getConfirmedOrAccepted = false)
{ {
var policies = (await _policyRepository.GetManyByUserIdAsync(user.Id)) var policies = getConfirmedOrAccepted ?
(await _policyRepository.GetManyAcceptedOrConfirmedByUserIdAsync(user.Id))
.Where(p => p.Type == PolicyType.MasterPassword && p.Enabled)
.ToList()
: (await _policyRepository.GetManyByUserIdAsync(user.Id))
.Where(p => p.Type == PolicyType.MasterPassword && p.Enabled) .Where(p => p.Type == PolicyType.MasterPassword && p.Enabled)
.ToList(); .ToList();

View File

@ -32,6 +32,7 @@ public class CurrentContext : ICurrentContext
public virtual string IpAddress { get; set; } public virtual string IpAddress { get; set; }
public virtual string CountryName { get; set; } public virtual string CountryName { get; set; }
public virtual List<CurrentContextOrganization> Organizations { get; set; } public virtual List<CurrentContextOrganization> Organizations { get; set; }
public virtual List<CurrentContextOrganization> OrganizationsConfirmedOrAccepted { get; set; }
public virtual List<CurrentContextProvider> Providers { get; set; } public virtual List<CurrentContextProvider> Providers { get; set; }
public virtual Guid? InstallationId { get; set; } public virtual Guid? InstallationId { get; set; }
public virtual Guid? OrganizationId { get; set; } public virtual Guid? OrganizationId { get; set; }
@ -481,6 +482,22 @@ public class CurrentContext : ICurrentContext
return Organizations; return Organizations;
} }
public async Task<ICollection<CurrentContextOrganization>> OrganizationAcceptedOrConfirmedAsync(
IOrganizationUserRepository organizationUserRepository, Guid userId)
{
if (OrganizationsConfirmedOrAccepted == null)
{
// If we haven't had our user id set, take the one passed in since we are about to get information
// for them anyways.
UserId ??= userId;
var userOrgs = await organizationUserRepository.GetManyDetailsByUserAsync(userId);
OrganizationsConfirmedOrAccepted = userOrgs.Where(ou => ou.Status == OrganizationUserStatusType.Confirmed || ou.Status == OrganizationUserStatusType.Accepted)
.Select(ou => new CurrentContextOrganization(ou)).ToList();
}
return OrganizationsConfirmedOrAccepted;
}
public async Task<ICollection<CurrentContextProvider>> ProviderMembershipAsync( public async Task<ICollection<CurrentContextProvider>> ProviderMembershipAsync(
IProviderUserRepository providerUserRepository, Guid userId) IProviderUserRepository providerUserRepository, Guid userId)
{ {

View File

@ -70,6 +70,9 @@ public interface ICurrentContext
Task<ICollection<CurrentContextOrganization>> OrganizationMembershipAsync( Task<ICollection<CurrentContextOrganization>> OrganizationMembershipAsync(
IOrganizationUserRepository organizationUserRepository, Guid userId); IOrganizationUserRepository organizationUserRepository, Guid userId);
Task<ICollection<CurrentContextOrganization>> OrganizationAcceptedOrConfirmedAsync(
IOrganizationUserRepository organizationUserRepository, Guid userId);
Task<ICollection<CurrentContextProvider>> ProviderMembershipAsync( Task<ICollection<CurrentContextProvider>> ProviderMembershipAsync(
IProviderUserRepository providerUserRepository, Guid userId); IProviderUserRepository providerUserRepository, Guid userId);

View File

@ -361,7 +361,7 @@ public abstract class BaseRequestValidator<T> where T : class
private async Task<MasterPasswordPolicyResponseModel> GetMasterPasswordPolicyAsync(User user) private async Task<MasterPasswordPolicyResponseModel> GetMasterPasswordPolicyAsync(User user)
{ {
// Check current context/cache to see if user is in any organizations, avoids extra DB call if not // Check current context/cache to see if user is in any organizations, avoids extra DB call if not
var orgs = (await CurrentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id)) var orgs = (await CurrentContext.OrganizationAcceptedOrConfirmedAsync(_organizationUserRepository, user.Id))
.ToList(); .ToList();
if (orgs.Count == 0) if (orgs.Count == 0)
@ -369,7 +369,7 @@ public abstract class BaseRequestValidator<T> where T : class
return null; return null;
} }
return new MasterPasswordPolicyResponseModel(await PolicyService.GetMasterPasswordPolicyForUserAsync(user)); return new MasterPasswordPolicyResponseModel(await PolicyService.GetMasterPasswordPolicyForUserAsync(user, true));
} }
/// <summary> /// <summary>
@ -401,8 +401,8 @@ public abstract class BaseRequestValidator<T> where T : class
/// <summary> /// <summary>
/// Builds the custom response that will be sent to the client upon successful authentication, which /// Builds the custom response that will be sent to the client upon successful authentication, which
/// includes the information needed for the client to initialize the user's account in state. /// includes the information needed for the client to initialize the user's account in state.
/// </summary> /// </summary>
/// <param name="user">The authenticated user.</param> /// <param name="user">The authenticated user.</param>
/// <param name="context">The current request context.</param> /// <param name="context">The current request context.</param>
/// <param name="device">The device used for authentication.</param> /// <param name="device">The device used for authentication.</param>
/// <param name="sendRememberToken">Whether to send a 2FA remember token.</param> /// <param name="sendRememberToken">Whether to send a 2FA remember token.</param>

View File

@ -61,6 +61,19 @@ public class PolicyRepository : Repository<Policy, Guid>, IPolicyRepository
} }
} }
public async Task<ICollection<Policy>> GetManyAcceptedOrConfirmedByUserIdAsync(Guid userId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<Policy>(
$"[{Schema}].[{Table}_ReadAcceptedOrConfirmedByUserId]",
new { UserId = userId },
commandType: CommandType.StoredProcedure);
return results.ToList();
}
}
public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId) public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId)
{ {
using (var connection = new SqlConnection(ConnectionString)) using (var connection = new SqlConnection(ConnectionString))

View File

@ -20,37 +20,41 @@ public class PolicyRepository : Repository<AdminConsoleEntities.Policy, Policy,
public async Task<AdminConsoleEntities.Policy> GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type) public async Task<AdminConsoleEntities.Policy> GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type)
{ {
using (var scope = ServiceScopeFactory.CreateScope()) using var scope = ServiceScopeFactory.CreateScope();
{ var dbContext = GetDatabaseContext(scope);
var dbContext = GetDatabaseContext(scope); var results = await dbContext.Policies
var results = await dbContext.Policies .FirstOrDefaultAsync(p => p.OrganizationId == organizationId && p.Type == type);
.FirstOrDefaultAsync(p => p.OrganizationId == organizationId && p.Type == type); return Mapper.Map<AdminConsoleEntities.Policy>(results);
return Mapper.Map<AdminConsoleEntities.Policy>(results);
}
} }
public async Task<ICollection<AdminConsoleEntities.Policy>> GetManyByOrganizationIdAsync(Guid organizationId) public async Task<ICollection<AdminConsoleEntities.Policy>> GetManyByOrganizationIdAsync(Guid organizationId)
{ {
using (var scope = ServiceScopeFactory.CreateScope()) using var scope = ServiceScopeFactory.CreateScope();
{ var dbContext = GetDatabaseContext(scope);
var dbContext = GetDatabaseContext(scope); var results = await dbContext.Policies
var results = await dbContext.Policies .Where(p => p.OrganizationId == organizationId)
.Where(p => p.OrganizationId == organizationId) .ToListAsync();
.ToListAsync(); return Mapper.Map<List<AdminConsoleEntities.Policy>>(results);
return Mapper.Map<List<AdminConsoleEntities.Policy>>(results);
}
} }
public async Task<ICollection<AdminConsoleEntities.Policy>> GetManyByUserIdAsync(Guid userId) public async Task<ICollection<AdminConsoleEntities.Policy>> GetManyByUserIdAsync(Guid userId)
{ {
using (var scope = ServiceScopeFactory.CreateScope()) using var scope = ServiceScopeFactory.CreateScope();
{ var dbContext = GetDatabaseContext(scope);
var dbContext = GetDatabaseContext(scope);
var query = new PolicyReadByUserIdQuery(userId); var query = new PolicyReadByUserIdQuery(userId);
var results = await query.Run(dbContext).ToListAsync(); var results = await query.Run(dbContext).ToListAsync();
return Mapper.Map<List<AdminConsoleEntities.Policy>>(results); return Mapper.Map<List<AdminConsoleEntities.Policy>>(results);
} }
public async Task<ICollection<AdminConsoleEntities.Policy>> GetManyAcceptedOrConfirmedByUserIdAsync(Guid userId)
{
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
var query = new PolicyReadAcceptedOrConfirmedByUserIdQuery(userId);
var results = await query.Run(dbContext).ToListAsync();
return Mapper.Map<List<AdminConsoleEntities.Policy>>(results);
} }
public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId) public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId)

View File

@ -0,0 +1,31 @@
using Bit.Core.Enums;
using Bit.Infrastructure.EntityFramework.AdminConsole.Models;
using Bit.Infrastructure.EntityFramework.Repositories;
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
namespace Bit.Infrastructure.EntityFramework.AdminConsole.Repositories.Queries;
public class PolicyReadAcceptedOrConfirmedByUserIdQuery : IQuery<Policy>
{
private readonly Guid _userId;
public PolicyReadAcceptedOrConfirmedByUserIdQuery(Guid userId)
{
_userId = userId;
}
public IQueryable<Policy> Run(DatabaseContext dbContext)
{
var query = from p in dbContext.Policies
join ou in dbContext.OrganizationUsers
on p.OrganizationId equals ou.OrganizationId
join o in dbContext.Organizations
on ou.OrganizationId equals o.Id
where ou.UserId == _userId &&
(ou.Status == OrganizationUserStatusType.Confirmed
|| ou.Status == OrganizationUserStatusType.Accepted)
select p;
return query;
}
}

View File

@ -0,0 +1,18 @@
CREATE PROCEDURE [dbo].[Policy_ReadAcceptedOrConfirmedByUserId]
@UserId UNIQUEIDENTIFIER
AS
BEGIN
SET NOCOUNT ON
SELECT
P.*
FROM
[dbo].[PolicyView] P
INNER JOIN
[dbo].[OrganizationUser] OU ON P.[OrganizationId] = OU.[OrganizationId]
INNER JOIN
[dbo].[Organization] O ON OU.[OrganizationId] = O.[Id]
WHERE
OU.[UserId] = @UserId
AND (OU.[Status] = 1 OR OU.[Status] = 2) -- 1 = Accepted, 2 = Confirmed
END

View File

@ -0,0 +1,18 @@
CREATE OR ALTER PROCEDURE [dbo].[Policy_ReadAcceptedOrConfirmedByUserId]
@UserId UNIQUEIDENTIFIER
AS
BEGIN
SET NOCOUNT ON
SELECT
P.*
FROM
[dbo].[PolicyView] P
INNER JOIN
[dbo].[OrganizationUser] OU ON P.[OrganizationId] = OU.[OrganizationId]
INNER JOIN
[dbo].[Organization] O ON OU.[OrganizationId] = O.[Id]
WHERE
OU.[UserId] = @UserId
AND (OU.[Status] = 1 OR OU.[Status] = 2) -- 1 = Accepted, 2 = Confirmed
END