diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 416bcaaebc..1003a65b51 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -143,6 +143,7 @@ public static class FeatureFlagKeys public const string StorageReseedRefactor = "storage-reseed-refactor"; public const string TrialPayment = "PM-8163-trial-payment"; public const string Pm3478RefactorOrganizationUserApi = "pm-3478-refactor-organizationuser-api"; + public const string RemoveServerVersionHeader = "remove-server-version-header"; public static List GetAllKeys() { diff --git a/src/SharedWeb/Utilities/RequestLoggingMiddleware.cs b/src/SharedWeb/Utilities/RequestLoggingMiddleware.cs new file mode 100644 index 0000000000..4fb0e8f92e --- /dev/null +++ b/src/SharedWeb/Utilities/RequestLoggingMiddleware.cs @@ -0,0 +1,117 @@ +using System.Collections; +using Bit.Core; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; + +#nullable enable + +namespace Bit.SharedWeb.Utilities; + +public sealed class RequestLoggingMiddleware +{ + private readonly RequestDelegate _next; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + + public RequestLoggingMiddleware(RequestDelegate next, ILogger logger, GlobalSettings globalSettings) + { + _next = next; + _logger = logger; + _globalSettings = globalSettings; + } + + public Task Invoke(HttpContext context, IFeatureService featureService) + { + if (!featureService.IsEnabled(FeatureFlagKeys.RemoveServerVersionHeader)) + { + context.Response.OnStarting(() => + { + context.Response.Headers.Append("Server-Version", AssemblyHelpers.GetVersion()); + return Task.CompletedTask; + }); + } + + using (_logger.BeginScope( + new RequestLogScope(context.GetIpAddress(_globalSettings), + GetHeaderValue(context, "user-agent"), + GetHeaderValue(context, "device-type"), + GetHeaderValue(context, "device-type")))) + { + return _next(context); + } + + static string? GetHeaderValue(HttpContext httpContext, string header) + { + if (httpContext.Request.Headers.TryGetValue(header, out var value)) + { + return value; + } + + return null; + } + } + + + private sealed class RequestLogScope : IReadOnlyList> + { + private string? _cachedToString; + + public RequestLogScope(string? ipAddress, string? userAgent, string? deviceType, string? origin) + { + IpAddress = ipAddress; + UserAgent = userAgent; + DeviceType = deviceType; + Origin = origin; + } + + public KeyValuePair this[int index] + { + get + { + if (index == 0) + { + return new KeyValuePair(nameof(IpAddress), IpAddress); + } + else if (index == 1) + { + return new KeyValuePair(nameof(UserAgent), UserAgent); + } + else if (index == 2) + { + return new KeyValuePair(nameof(DeviceType), DeviceType); + } + else if (index == 3) + { + return new KeyValuePair(nameof(Origin), Origin); + } + + throw new ArgumentOutOfRangeException(nameof(index)); + } + } + + public int Count => 4; + + public string? IpAddress { get; } + public string? UserAgent { get; } + public string? DeviceType { get; } + public string? Origin { get; } + + public IEnumerator> GetEnumerator() + { + for (var i = 0; i < Count; i++) + { + yield return this[i]; + } + } + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public override string ToString() + { + _cachedToString ??= $"IpAddress:{IpAddress} UserAgent:{UserAgent} DeviceType:{DeviceType} Origin:{Origin}"; + return _cachedToString; + } + } +} diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index be451ea318..bd3aecf2f5 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -48,7 +48,6 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.HttpOverrides; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc.Localization; @@ -60,7 +59,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using Serilog.Context; using StackExchange.Redis; using NoopRepos = Bit.Core.Repositories.Noop; using Role = Bit.Core.Entities.Role; @@ -540,31 +538,7 @@ public static class ServiceCollectionExtensions public static void UseDefaultMiddleware(this IApplicationBuilder app, IWebHostEnvironment env, GlobalSettings globalSettings) { - string GetHeaderValue(HttpContext httpContext, string header) - { - if (httpContext.Request.Headers.ContainsKey(header)) - { - return httpContext.Request.Headers[header]; - } - return null; - } - - // Add version information to response headers - app.Use(async (httpContext, next) => - { - using (LogContext.PushProperty("IPAddress", httpContext.GetIpAddress(globalSettings))) - using (LogContext.PushProperty("UserAgent", GetHeaderValue(httpContext, "user-agent"))) - using (LogContext.PushProperty("DeviceType", GetHeaderValue(httpContext, "device-type"))) - using (LogContext.PushProperty("Origin", GetHeaderValue(httpContext, "origin"))) - { - httpContext.Response.OnStarting((state) => - { - httpContext.Response.Headers.Append("Server-Version", AssemblyHelpers.GetVersion()); - return Task.FromResult(0); - }, null); - await next.Invoke(); - } - }); + app.UseMiddleware(); } public static void UseForwardedHeaders(this IApplicationBuilder app, IGlobalSettings globalSettings)