From f8f7c339c36ba42c597e81e67489735015971904 Mon Sep 17 00:00:00 2001 From: Kyle Spearrin Date: Fri, 25 Jan 2019 00:01:24 -0500 Subject: [PATCH] get request up from cloudflare header --- src/Core/CurrentContext.cs | 25 ++++++++----------- .../Utilities/CurrentContextMiddleware.cs | 4 +-- src/Notifications/NotificationsHub.cs | 8 +++--- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/Core/CurrentContext.cs b/src/Core/CurrentContext.cs index ebf8a22a8d..977d56f7a0 100644 --- a/src/Core/CurrentContext.cs +++ b/src/Core/CurrentContext.cs @@ -12,20 +12,21 @@ namespace Bit.Core { public class CurrentContext { + private const string CloudFlareConnectingIp = "CF-Connecting-IP"; + private bool _builtHttpContext; private bool _builtClaimsPrincipal; - private string _ip; public virtual HttpContext HttpContext { get; set; } public virtual Guid? UserId { get; set; } public virtual User User { get; set; } public virtual string DeviceIdentifier { get; set; } public virtual DeviceType? DeviceType { get; set; } - public virtual string IpAddress => GetRequestIp(); + public virtual string IpAddress { get; set; } public virtual List Organizations { get; set; } public virtual Guid? InstallationId { get; set; } - public void Build(HttpContext httpContext) + public void Build(HttpContext httpContext, GlobalSettings globalSettings) { if(_builtHttpContext) { @@ -34,7 +35,7 @@ namespace Bit.Core _builtHttpContext = true; HttpContext = httpContext; - Build(httpContext.User); + Build(httpContext.User, globalSettings); if(DeviceIdentifier == null && httpContext.Request.Headers.ContainsKey("Device-Identifier")) { @@ -48,7 +49,7 @@ namespace Bit.Core } } - public void Build(ClaimsPrincipal user) + public void Build(ClaimsPrincipal user, GlobalSettings globalSettings) { if(_builtClaimsPrincipal) { @@ -56,6 +57,7 @@ namespace Bit.Core } _builtClaimsPrincipal = true; + IpAddress = GetRequestIp(globalSettings); if(user == null || !user.Claims.Any()) { return; @@ -158,24 +160,19 @@ namespace Bit.Core return Organizations; } - private string GetRequestIp() + private string GetRequestIp(GlobalSettings globalSettings) { - if(!string.IsNullOrWhiteSpace(_ip)) - { - return _ip; - } - if(HttpContext == null) { return null; } - if(string.IsNullOrWhiteSpace(_ip)) + if(!globalSettings.SelfHosted && HttpContext.Request.Headers.ContainsKey(CloudFlareConnectingIp)) { - _ip = HttpContext.Connection?.RemoteIpAddress?.ToString(); + return HttpContext.Request.Headers[CloudFlareConnectingIp].ToString(); } - return _ip; + return HttpContext.Connection?.RemoteIpAddress?.ToString(); } private string GetClaimValue(Dictionary> claims, string type) diff --git a/src/Core/Utilities/CurrentContextMiddleware.cs b/src/Core/Utilities/CurrentContextMiddleware.cs index d6a16a9844..73607df3b4 100644 --- a/src/Core/Utilities/CurrentContextMiddleware.cs +++ b/src/Core/Utilities/CurrentContextMiddleware.cs @@ -12,9 +12,9 @@ namespace Bit.Core.Utilities _next = next; } - public async Task Invoke(HttpContext httpContext, CurrentContext currentContext) + public async Task Invoke(HttpContext httpContext, CurrentContext currentContext, GlobalSettings globalSettings) { - currentContext.Build(httpContext); + currentContext.Build(httpContext, globalSettings); await _next.Invoke(httpContext); } } diff --git a/src/Notifications/NotificationsHub.cs b/src/Notifications/NotificationsHub.cs index 3bfc227c3a..0d14d08ffe 100644 --- a/src/Notifications/NotificationsHub.cs +++ b/src/Notifications/NotificationsHub.cs @@ -9,16 +9,18 @@ namespace Bit.Notifications public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub { private readonly ConnectionCounter _connectionCounter; + private readonly GlobalSettings _globalSettings; - public NotificationsHub(ConnectionCounter connectionCounter) + public NotificationsHub(ConnectionCounter connectionCounter, GlobalSettings globalSettings) { _connectionCounter = connectionCounter; + _globalSettings = globalSettings; } public override async Task OnConnectedAsync() { var currentContext = new CurrentContext(); - currentContext.Build(Context.User); + currentContext.Build(Context.User, _globalSettings); if(currentContext.Organizations != null) { foreach(var org in currentContext.Organizations) @@ -33,7 +35,7 @@ namespace Bit.Notifications public override async Task OnDisconnectedAsync(Exception exception) { var currentContext = new CurrentContext(); - currentContext.Build(Context.User); + currentContext.Build(Context.User, _globalSettings); if(currentContext.Organizations != null) { foreach(var org in currentContext.Organizations)