diff --git a/src/Api/Controllers/PushController.cs b/src/Api/Controllers/PushController.cs index 0cbc43d87c..1553500010 100644 --- a/src/Api/Controllers/PushController.cs +++ b/src/Api/Controllers/PushController.cs @@ -2,6 +2,8 @@ using Microsoft.AspNetCore.Mvc; using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; +using Bit.Core; +using Bit.Core.Exceptions; namespace Bit.Api.Controllers { @@ -10,16 +12,24 @@ namespace Bit.Api.Controllers public class PushController : Controller { private readonly IPushRegistrationService _pushRegistrationService; + private readonly CurrentContext _currentContext; public PushController( - IPushRegistrationService pushRegistrationService) + IPushRegistrationService pushRegistrationService, + CurrentContext currentContext) { + _currentContext = currentContext; _pushRegistrationService = pushRegistrationService; } [HttpGet("register")] public Object Register() { + if(!_currentContext.InstallationId.HasValue) + { + throw new BadRequestException("bad request."); + } + return new { Foo = "bar" }; } } diff --git a/src/Api/Middleware/CurrentContextMiddleware.cs b/src/Api/Middleware/CurrentContextMiddleware.cs index 29dda8a9a6..f6413d7af6 100644 --- a/src/Api/Middleware/CurrentContextMiddleware.cs +++ b/src/Api/Middleware/CurrentContextMiddleware.cs @@ -1,6 +1,9 @@ using Bit.Core; using Microsoft.AspNetCore.Http; +using System; +using System.Collections.Generic; using System.Linq; +using System.Security.Claims; using System.Threading.Tasks; namespace Bit.Api.Middleware @@ -18,38 +21,49 @@ namespace Bit.Api.Middleware { if(httpContext.User != null) { - var securityStampClaim = httpContext.User.Claims.FirstOrDefault(c => c.Type == "device"); - currentContext.DeviceIdentifier = securityStampClaim?.Value; + var claimsDict = httpContext.User.Claims + .GroupBy(c => c.Type) + .ToDictionary(c => c.Key, c => c.Select(v => v)); - var orgOwnerClaims = httpContext.User.Claims.Where(c => c.Type == "orgowner"); - if(orgOwnerClaims.Any()) + var clientId = GetClaimValue(claimsDict, "client_id"); + var clientSubject = GetClaimValue(claimsDict, "client_sub"); + if((clientId?.StartsWith("installation.") ?? false) && clientSubject != null) { - currentContext.Organizations.AddRange(orgOwnerClaims.Select(c => + Guid idGuid; + if(Guid.TryParse(clientSubject, out idGuid)) + { + currentContext.InstallationId = idGuid; + } + } + + currentContext.DeviceIdentifier = GetClaimValue(claimsDict, "device"); + + if(claimsDict.ContainsKey("orgowner")) + { + currentContext.Organizations.AddRange(claimsDict["orgowner"].Select(c => new CurrentContext.CurrentContentOrganization { - Id = new System.Guid(c.Value), + Id = new Guid(c.Value), Type = Core.Enums.OrganizationUserType.Owner })); } - var orgAdminClaims = httpContext.User.Claims.Where(c => c.Type == "orgadmin"); - if(orgAdminClaims.Any()) + if(claimsDict.ContainsKey("orgadmin")) { - currentContext.Organizations.AddRange(orgAdminClaims.Select(c => + currentContext.Organizations.AddRange(claimsDict["orgadmin"].Select(c => new CurrentContext.CurrentContentOrganization { - Id = new System.Guid(c.Value), + Id = new Guid(c.Value), Type = Core.Enums.OrganizationUserType.Admin })); } - var orgUserClaims = httpContext.User.Claims.Where(c => c.Type == "orguser"); - if(orgUserClaims.Any()) + if(claimsDict.ContainsKey("orguser")) { - currentContext.Organizations.AddRange(orgUserClaims.Select(c => + currentContext.Organizations.AddRange(claimsDict["orguser"].Select(c => new CurrentContext.CurrentContentOrganization { - Id = new System.Guid(c.Value), + Id = new Guid(c.Value), Type = Core.Enums.OrganizationUserType.User })); } @@ -62,5 +76,15 @@ namespace Bit.Api.Middleware await _next.Invoke(httpContext); } + + private string GetClaimValue(Dictionary> claims, string type) + { + if(!claims.ContainsKey(type)) + { + return null; + } + + return claims[type].FirstOrDefault()?.Value; + } } } diff --git a/src/Core/CurrentContext.cs b/src/Core/CurrentContext.cs index b867e6a7bf..3927894e2b 100644 --- a/src/Core/CurrentContext.cs +++ b/src/Core/CurrentContext.cs @@ -11,6 +11,7 @@ namespace Bit.Core public virtual User User { get; set; } public virtual string DeviceIdentifier { get; set; } public virtual List Organizations { get; set; } = new List(); + public virtual Guid? InstallationId { get; set; } public bool OrganizationUser(Guid orgId) { diff --git a/src/Core/IdentityServer/ApiResources.cs b/src/Core/IdentityServer/ApiResources.cs index 70da942a00..4e6dc1e895 100644 --- a/src/Core/IdentityServer/ApiResources.cs +++ b/src/Core/IdentityServer/ApiResources.cs @@ -21,7 +21,7 @@ namespace Bit.Core.IdentityServer "orgadmin", "orguser" }), - new ApiResource("api.push") + new ApiResource("api.push", new string[] { JwtClaimTypes.Subject }) }; } } diff --git a/src/Core/IdentityServer/ClientStore.cs b/src/Core/IdentityServer/ClientStore.cs index 7dfc063e0e..3ea4232682 100644 --- a/src/Core/IdentityServer/ClientStore.cs +++ b/src/Core/IdentityServer/ClientStore.cs @@ -4,6 +4,8 @@ using IdentityServer4.Models; using System.Collections.Generic; using Bit.Core.Repositories; using System; +using System.Security.Claims; +using IdentityModel; namespace Bit.Core.IdentityServer { @@ -37,7 +39,8 @@ namespace Bit.Core.IdentityServer AllowedScopes = new string[] { "api.push" }, AllowedGrantTypes = GrantTypes.ClientCredentials, AccessTokenLifetime = 3600 * 24, - Enabled = installation.Enabled + Enabled = installation.Enabled, + Claims = new List { new Claim(JwtClaimTypes.Subject, installation.Id.ToString()) } }; } }