diff --git a/src/Api/AdminConsole/Authorization/ClaimsExtensions.cs b/src/Api/AdminConsole/Authorization/ClaimsExtensions.cs index 5aa6c1fd77..ef55a0ebbb 100644 --- a/src/Api/AdminConsole/Authorization/ClaimsExtensions.cs +++ b/src/Api/AdminConsole/Authorization/ClaimsExtensions.cs @@ -4,23 +4,29 @@ using System.Security.Claims; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Identity; +using Bit.Core.Models.Data; namespace Bit.Api.AdminConsole.Authorization; public static class ClaimsExtensions { + // Relevant claim types for organization roles, SM access, and custom permissions + private static readonly IEnumerable _relevantClaimTypes = new List{ + Claims.OrganizationOwner, + Claims.OrganizationAdmin, + Claims.OrganizationCustom, + Claims.OrganizationUser, + Claims.SecretsManagerAccess, + }.Concat(new Permissions().ClaimsMap.Select(c => c.ClaimName)); + public static CurrentContextOrganization? GetCurrentContextOrganization(this ClaimsPrincipal user, Guid organizationId) { var claimsDict = user.Claims + .Where(c => _relevantClaimTypes.Contains(c.Type) && Guid.TryParse(c.Value, out _)) .GroupBy(c => c.Type) - .ToDictionary(c => c.Key, c => c.Select(v => v)); - - var accessSecretsManager = claimsDict.TryGetValue(Claims.SecretsManagerAccess, out var value) - ? value - .Where(s => Guid.TryParse(s.Value, out _)) - .Select(s => new Guid(s.Value)) - .ToHashSet() - : []; + .ToDictionary( + c => c.Key, + c => c.Select(v => new Guid(v.Value))); var role = claimsDict.GetRoleForOrganizationId(organizationId); if (!role.HasValue) @@ -33,19 +39,19 @@ public static class ClaimsExtensions { Id = organizationId, Type = role.Value, - AccessSecretsManager = accessSecretsManager.Contains(organizationId), + AccessSecretsManager = claimsDict.ContainsOrganizationId(Claims.SecretsManagerAccess, organizationId), Permissions = role == OrganizationUserType.Custom - ? CurrentContext.SetOrganizationPermissionsFromClaims(organizationId.ToString(), claimsDict) + ? claimsDict.GetPermissionsFromClaims(organizationId) : null }; } - private static bool ContainsOrganizationId(this Dictionary> claimsDict, string claimType, + private static bool ContainsOrganizationId(this Dictionary> claimsDict, string claimType, Guid organizationId) => claimsDict.TryGetValue(claimType, out var claimValue) && - claimValue.Any(c => c.Value.EqualsGuid(organizationId)); + claimValue.Any(guid => guid == organizationId); - private static OrganizationUserType? GetRoleForOrganizationId(this Dictionary> claimsDict, + private static OrganizationUserType? GetRoleForOrganizationId(this Dictionary> claimsDict, Guid organizationId) { if (claimsDict.ContainsOrganizationId(Claims.OrganizationOwner, organizationId)) @@ -71,6 +77,22 @@ public static class ClaimsExtensions return null; } - private static bool EqualsGuid(this string value, Guid guid) - => Guid.TryParse(value, out var parsedValue) && parsedValue == guid; + private static Permissions GetPermissionsFromClaims(this Dictionary> claimsDict, Guid organizationId) + { + return new Permissions + { + AccessEventLogs = claimsDict.ContainsOrganizationId(Claims.AccessEventLogs, organizationId), + AccessImportExport = claimsDict.ContainsOrganizationId(Claims.AccessImportExport, organizationId), + AccessReports = claimsDict.ContainsOrganizationId(Claims.AccessReports, organizationId), + CreateNewCollections = claimsDict.ContainsOrganizationId(Claims.CreateNewCollections, organizationId), + EditAnyCollection = claimsDict.ContainsOrganizationId(Claims.EditAnyCollection, organizationId), + DeleteAnyCollection = claimsDict.ContainsOrganizationId(Claims.DeleteAnyCollection, organizationId), + ManageGroups = claimsDict.ContainsOrganizationId(Claims.ManageGroups, organizationId), + ManagePolicies = claimsDict.ContainsOrganizationId(Claims.ManagePolicies, organizationId), + ManageSso = claimsDict.ContainsOrganizationId(Claims.ManageSso, organizationId), + ManageUsers = claimsDict.ContainsOrganizationId(Claims.ManageUsers, organizationId), + ManageResetPassword = claimsDict.ContainsOrganizationId(Claims.ManageResetPassword, organizationId), + ManageScim = claimsDict.ContainsOrganizationId(Claims.ManageScim, organizationId), + }; + } } diff --git a/src/Core/AdminConsole/Models/Data/Permissions.cs b/src/Core/AdminConsole/Models/Data/Permissions.cs index 9edc3f1d50..a50f08e0f7 100644 --- a/src/Core/AdminConsole/Models/Data/Permissions.cs +++ b/src/Core/AdminConsole/Models/Data/Permissions.cs @@ -1,4 +1,5 @@ using System.Text.Json.Serialization; +using Bit.Core.Identity; namespace Bit.Core.Models.Data; @@ -20,17 +21,17 @@ public class Permissions [JsonIgnore] public List<(bool Permission, string ClaimName)> ClaimsMap => new() { - (AccessEventLogs, "accesseventlogs"), - (AccessImportExport, "accessimportexport"), - (AccessReports, "accessreports"), - (CreateNewCollections, "createnewcollections"), - (EditAnyCollection, "editanycollection"), - (DeleteAnyCollection, "deleteanycollection"), - (ManageGroups, "managegroups"), - (ManagePolicies, "managepolicies"), - (ManageSso, "managesso"), - (ManageUsers, "manageusers"), - (ManageResetPassword, "manageresetpassword"), - (ManageScim, "managescim"), + (AccessEventLogs, Claims.AccessEventLogs), + (AccessImportExport, Claims.AccessImportExport), + (AccessReports, Claims.AccessReports), + (CreateNewCollections, Claims.CreateNewCollections), + (EditAnyCollection, Claims.EditAnyCollection), + (DeleteAnyCollection, Claims.DeleteAnyCollection), + (ManageGroups, Claims.ManageGroups), + (ManagePolicies, Claims.ManagePolicies), + (ManageSso, Claims.ManageSso), + (ManageUsers, Claims.ManageUsers), + (ManageResetPassword, Claims.ManageResetPassword), + (ManageScim, Claims.ManageScim), }; } diff --git a/src/Core/Identity/Claims.cs b/src/Core/Identity/Claims.cs index 65d5eb210a..9478ab3ca8 100644 --- a/src/Core/Identity/Claims.cs +++ b/src/Core/Identity/Claims.cs @@ -22,4 +22,18 @@ public static class Claims // General public const string Type = "type"; + + // Organization permissions + public const string AccessEventLogs = "accesseventlogs"; + public const string AccessImportExport = "accessimportexport"; + public const string AccessReports = "accessreports"; + public const string CreateNewCollections = "createnewcollections"; + public const string EditAnyCollection = "editanycollection"; + public const string DeleteAnyCollection = "deleteanycollection"; + public const string ManageGroups = "managegroups"; + public const string ManagePolicies = "managepolicies"; + public const string ManageSso = "managesso"; + public const string ManageUsers = "manageusers"; + public const string ManageResetPassword = "manageresetpassword"; + public const string ManageScim = "managescim"; }