From 2bf8438ff711c76addd3fcd1b1980e4e50e39048 Mon Sep 17 00:00:00 2001 From: Shane Melton Date: Wed, 31 Aug 2022 14:17:29 -0700 Subject: [PATCH] [EC-502] Rate Limiting Improvements (#2231) * [EC-502] Add custom Redis IP rate limit processing strategy * [EC-502] Formatting * [EC-502] Add documentation and app setting config options * [EC-502] Formatting * [EC-502] Fix appsettings.json keys * [EC-502] Replace magic string for cache key * [EC-502] Add tests for custom processing strategy * [EC-502] Formatting * [EC-502] Use base class for custom processing strategy * [EC-502] Fix failing test --- src/Api/appsettings.json | 5 + src/Core/Settings/GlobalSettings.cs | 21 ++ .../CustomRedisProcessingStrategy.cs | 102 +++++++++ src/Identity/appsettings.json | 5 + .../Utilities/ServiceCollectionExtensions.cs | 12 +- .../CustomRedisProcessingStrategyTests.cs | 198 ++++++++++++++++++ 6 files changed, 340 insertions(+), 3 deletions(-) create mode 100644 src/Core/Utilities/CustomRedisProcessingStrategy.cs create mode 100644 test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs diff --git a/src/Api/appsettings.json b/src/Api/appsettings.json index 92922d104f..fde1db479e 100644 --- a/src/Api/appsettings.json +++ b/src/Api/appsettings.json @@ -69,6 +69,11 @@ "accessKeyId": "SECRET", "accessKeySecret": "SECRET", "region": "SECRET" + }, + "distributedIpRateLimiting": { + "enabled": true, + "maxRedisTimeoutsThreshold": 10, + "slidingWindowSeconds": 120 } }, "IpRateLimitOptions": { diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index bd4087f3a9..7bb66377de 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -69,6 +69,8 @@ public class GlobalSettings : IGlobalSettings public virtual ISsoSettings Sso { get; set; } = new SsoSettings(); public virtual StripeSettings Stripe { get; set; } = new StripeSettings(); public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new 
TwoFactorAuthSettings(); + public virtual DistributedIpRateLimitingSettings DistributedIpRateLimiting { get; set; } = + new DistributedIpRateLimitingSettings(); public string BuildExternalUri(string explicitValue, string name) { @@ -498,4 +500,23 @@ public class GlobalSettings : IGlobalSettings { public bool EmailOnNewDeviceLogin { get; set; } = false; } + + public class DistributedIpRateLimitingSettings + { + public bool Enabled { get; set; } = true; + + /// <summary> + /// Maximum number of Redis timeouts that can be experienced within the sliding timeout + /// window before IP rate limiting is temporarily disabled. + /// TODO: Determine/discuss a suitable maximum + /// </summary> + public int MaxRedisTimeoutsThreshold { get; set; } = 10; + + /// <summary> + /// Length of the sliding window in seconds to track Redis timeout exceptions. The window starts at the first timeout and is not extended by later ones. + /// TODO: Determine/discuss a suitable sliding window + /// </summary> + public int SlidingWindowSeconds { get; set; } = 120; + } + } diff --git a/src/Core/Utilities/CustomRedisProcessingStrategy.cs b/src/Core/Utilities/CustomRedisProcessingStrategy.cs new file mode 100644 index 0000000000..e3b4acd68e --- /dev/null +++ b/src/Core/Utilities/CustomRedisProcessingStrategy.cs @@ -0,0 +1,102 @@ +using AspNetCoreRateLimit; +using AspNetCoreRateLimit.Redis; +using Bit.Core.Settings; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Logging; +using StackExchange.Redis; + +namespace Bit.Core.Utilities; + +/// <summary> +/// A modified version of <see cref="RedisProcessingStrategy"/> that gracefully +/// handles a disrupted Redis connection. If the connection is down or the number of failed requests within +/// a given time period exceeds the configured threshold, then rate limiting is temporarily disabled. +/// </summary> +/// <remarks> +/// This is necessary to ensure the service does not become unresponsive due to Redis being out of service, as +/// the default implementation would throw an exception and exit the request pipeline for all requests. 
+/// </remarks> +public class CustomRedisProcessingStrategy : RedisProcessingStrategy +{ + private readonly IConnectionMultiplexer _connectionMultiplexer; + private readonly ILogger _logger; + private readonly IMemoryCache _memoryCache; + private readonly GlobalSettings.DistributedIpRateLimitingSettings _distributedSettings; + + private const string _redisTimeoutCacheKey = "IpRateLimitRedisTimeout"; + + public CustomRedisProcessingStrategy( + IConnectionMultiplexer connectionMultiplexer, + IRateLimitConfiguration config, + ILogger logger, + IMemoryCache memoryCache, + GlobalSettings globalSettings) + : base(connectionMultiplexer, config, logger) + { + _connectionMultiplexer = connectionMultiplexer; + _logger = logger; + _memoryCache = memoryCache; + _distributedSettings = globalSettings.DistributedIpRateLimiting; + } + + public override async Task ProcessRequestAsync(ClientRequestIdentity requestIdentity, + RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions, + CancellationToken cancellationToken = default) + { + // If Redis is down entirely, skip rate limiting + if (!_connectionMultiplexer.IsConnected) + { + _logger.LogDebug("Redis connection is down, skipping IP rate limiting"); + return SkipRateLimitResult(); + } + + // Check if any Redis timeouts have occurred recently + if (_memoryCache.TryGetValue(_redisTimeoutCacheKey, out var timeoutCounter)) + { + // We've reached the threshold; back off from Redis and skip rate limiting for now + if (timeoutCounter.Count >= _distributedSettings.MaxRedisTimeoutsThreshold) + { + _logger.LogDebug( + "Redis timeout threshold has been exceeded, backing off and skipping IP rate limiting"); + return SkipRateLimitResult(); + } + } + + try + { + return await base.ProcessRequestAsync(requestIdentity, rule, counterKeyBuilder, rateLimitOptions, cancellationToken); + } + catch (RedisTimeoutException) + { + // If this is the first timeout we've had, start a new counter and window; the window's expiry is fixed here and later timeouts do not extend it + timeoutCounter ??= new 
TimeoutCounter() + { + Count = 0, + ExpiresAt = DateTime.UtcNow.AddSeconds(_distributedSettings.SlidingWindowSeconds) + }; + timeoutCounter.Count++; + + _memoryCache.Set(_redisTimeoutCacheKey, timeoutCounter, + new MemoryCacheEntryOptions { AbsoluteExpiration = timeoutCounter.ExpiresAt }); + + // Just because Redis timed out does not mean we should kill the request + return SkipRateLimitResult(); + } + } + + /// <summary> + /// A RateLimitCounter result used when the rate limiting middleware should + /// fail open and allow the request to proceed without checking request limits. + /// </summary> + private static RateLimitCounter SkipRateLimitResult() + { + return new RateLimitCounter { Count = 0, Timestamp = DateTime.UtcNow }; + } + + internal class TimeoutCounter + { + public DateTime ExpiresAt { get; init; } + + public int Count { get; set; } + } +} diff --git a/src/Identity/appsettings.json b/src/Identity/appsettings.json index fb3469b5fe..609a5004aa 100644 --- a/src/Identity/appsettings.json +++ b/src/Identity/appsettings.json @@ -59,6 +59,11 @@ "accessKeyId": "SECRET", "accessKeySecret": "SECRET", "region": "SECRET" + }, + "distributedIpRateLimiting": { + "enabled": true, + "maxRedisTimeoutsThreshold": 10, + "slidingWindowSeconds": 120 } }, "IpRateLimitOptions": { diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index b2efe511d0..7fab4b2eed 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -2,7 +2,6 @@ using System.Security.Claims; using System.Security.Cryptography.X509Certificates; using AspNetCoreRateLimit; -using AspNetCoreRateLimit.Redis; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.HostedServices; @@ -609,13 +608,20 @@ public static class ServiceCollectionExtensions services.AddHostedService(); services.AddSingleton(); - if (string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) + if 
(!globalSettings.DistributedIpRateLimiting.Enabled || string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) { services.AddInMemoryRateLimiting(); } else { - services.AddRedisRateLimiting(); // Requires a registered IConnectionMultiplexer + // Use in-memory stores for the IP and client policy stores as we don't currently use them + // and they add unnecessary Redis network delays checking for policies that don't exist + services.AddSingleton(); + services.AddSingleton(); + + // Use a custom Redis processing strategy that skips IP rate limiting if Redis is down or timing out + // Requires a registered IConnectionMultiplexer + services.AddSingleton(); } } diff --git a/test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs b/test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs new file mode 100644 index 0000000000..e5b9bd5549 --- /dev/null +++ b/test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs @@ -0,0 +1,198 @@ +using AspNetCoreRateLimit; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Logging; +using Moq; +using StackExchange.Redis; +using Xunit; + +namespace Bit.Core.Test.Utilities; + +public class CustomRedisProcessingStrategyTests +{ + #region Sample RateLimit Options for Testing + + private readonly GlobalSettings _sampleSettings = new() + { + DistributedIpRateLimiting = new GlobalSettings.DistributedIpRateLimitingSettings + { + Enabled = true, + MaxRedisTimeoutsThreshold = 2, + SlidingWindowSeconds = 5 + } + }; + + private readonly ClientRequestIdentity _sampleClientId = new() + { + ClientId = "test", + ClientIp = "127.0.0.1", + HttpVerb = "GET", + Path = "/" + }; + + private readonly RateLimitRule _sampleRule = new() { Endpoint = "/", Limit = 5, Period = "1m", PeriodTimespan = TimeSpan.FromMinutes(1) }; + + private readonly RateLimitOptions _sampleOptions = new() { }; + + #endregion + + private readonly Mock _mockCounterKeyBuilder = new(); + private Mock _mockDb; + + public 
CustomRedisProcessingStrategyTests() + { + _mockCounterKeyBuilder + .Setup(x => + x.Build(It.IsAny(), It.IsAny())) + .Returns(_sampleClientId.ClientId); + } + + [Fact] + public async Task IncrementRateLimitCount_When_RedisIsHealthy() + { + // Arrange + var strategy = BuildProcessingStrategy(); + + // Act + var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions, + CancellationToken.None); + + // Assert + Assert.Equal(1, result.Count); + VerifyRedisCalls(Times.Once()); + } + + [Fact] + public async Task SkipRateLimit_When_RedisIsDown() + { + // Arrange: the mocked multiplexer reports IsConnected == false + var strategy = BuildProcessingStrategy(false); + + // Act + var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions, + CancellationToken.None); + + // Assert + Assert.Equal(0, result.Count); + VerifyRedisCalls(Times.Never()); + } + + [Fact] + public async Task SkipRateLimit_When_TimeoutThresholdExceeded() + { + // Arrange: the cache already holds a timeout counter above the threshold + var mockCache = new Mock(); + object existingCount = new CustomRedisProcessingStrategy.TimeoutCounter + { + Count = _sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold + 1 + }; + mockCache.Setup(x => x.TryGetValue(It.IsAny(), out existingCount)).Returns(true); + + var strategy = BuildProcessingStrategy(mockCache: mockCache.Object); + + // Act + var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions, + CancellationToken.None); + + // Assert + Assert.Equal(0, result.Count); + VerifyRedisCalls(Times.Never()); + } + + [Fact] + public async Task SkipRateLimit_When_RedisTimeoutException() + { + // Arrange + var mockCache = new Mock(); + var mockCacheEntry = new Mock(); + mockCacheEntry.SetupAllProperties(); + mockCache.Setup(x => x.CreateEntry(It.IsAny())).Returns(mockCacheEntry.Object); + + var strategy = BuildProcessingStrategy(mockCache: mockCache.Object, throwRedisTimeout: true); + 
+ // Act + var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions, + CancellationToken.None); + + var timeoutCounter = ((CustomRedisProcessingStrategy.TimeoutCounter)mockCacheEntry.Object.Value); + + // Assert + Assert.Equal(0, result.Count); // Skip rate limiting + VerifyRedisCalls(Times.Once()); + + Assert.Equal(1, timeoutCounter.Count); // Timeout count increased/cached + Assert.NotNull(mockCacheEntry.Object.AbsoluteExpiration); + mockCache.Verify(x => x.CreateEntry(It.IsAny())); + } + + [Fact] + public async Task BackoffRedis_After_ThresholdExceeded() + { + // Arrange + var memoryCache = new MemoryCache(new MemoryCacheOptions()); + var strategy = BuildProcessingStrategy(mockCache: memoryCache, throwRedisTimeout: true); + + // Act + + // Redis Timeout 1 + await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions, + CancellationToken.None); + + // Redis Timeout 2 + await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions, + CancellationToken.None); + + // Threshold reached: this request should skip Redis entirely + await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions, + CancellationToken.None); + + // Assert: Redis was only called until the timeout threshold was reached + VerifyRedisCalls(Times.Exactly(_sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold)); + } + + private void VerifyRedisCalls(Times times) + { + _mockDb.Verify(x => + x.ScriptEvaluateAsync(It.IsAny(), It.IsAny(), It.IsAny()), + times); + } + + private CustomRedisProcessingStrategy BuildProcessingStrategy( + bool isRedisConnected = true, + bool throwRedisTimeout = false, + IMemoryCache mockCache = null) + { + var mockRedisConnection = new Mock(); + + mockRedisConnection.Setup(x => x.IsConnected).Returns(isRedisConnected); + + _mockDb = new Mock(); + + var mockScriptEvaluate = _mockDb + .Setup(x => + x.ScriptEvaluateAsync(It.IsAny(), It.IsAny(), It.IsAny())); + + 
if (throwRedisTimeout) + { + mockScriptEvaluate.ThrowsAsync(new RedisTimeoutException("Timeout", CommandStatus.WaitingToBeSent)); + } + else + { + mockScriptEvaluate.ReturnsAsync(RedisResult.Create(1)); + } + + mockRedisConnection + .Setup(x => + x.GetDatabase(It.IsAny(), It.IsAny())) + .Returns(_mockDb.Object); + + var mockLogger = new Mock>(); + var mockConfig = new Mock(); + + mockCache ??= new Mock().Object; + + return new CustomRedisProcessingStrategy(mockRedisConnection.Object, mockConfig.Object, + mockLogger.Object, mockCache, _sampleSettings); + } +}