diff --git a/src/Api/appsettings.json b/src/Api/appsettings.json
index 92922d104f..fde1db479e 100644
--- a/src/Api/appsettings.json
+++ b/src/Api/appsettings.json
@@ -69,6 +69,11 @@
"accessKeyId": "SECRET",
"accessKeySecret": "SECRET",
"region": "SECRET"
+ },
+ "distributedIpRateLimiting": {
+ "enabled": true,
+ "maxRedisTimeoutsThreshold": 10,
+ "slidingWindowSeconds": 120
}
},
"IpRateLimitOptions": {
diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs
index bd4087f3a9..7bb66377de 100644
--- a/src/Core/Settings/GlobalSettings.cs
+++ b/src/Core/Settings/GlobalSettings.cs
@@ -69,6 +69,8 @@ public class GlobalSettings : IGlobalSettings
public virtual ISsoSettings Sso { get; set; } = new SsoSettings();
public virtual StripeSettings Stripe { get; set; } = new StripeSettings();
public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new TwoFactorAuthSettings();
+ public virtual DistributedIpRateLimitingSettings DistributedIpRateLimiting { get; set; } =
+ new DistributedIpRateLimitingSettings();
public string BuildExternalUri(string explicitValue, string name)
{
@@ -498,4 +500,23 @@ public class GlobalSettings : IGlobalSettings
{
public bool EmailOnNewDeviceLogin { get; set; } = false;
}
+
+ public class DistributedIpRateLimitingSettings
+ {
+ public bool Enabled { get; set; } = true;
+
+ ///
+ /// Maximum number of Redis timeouts that can be experienced within the sliding timeout
+ /// window before IP rate limiting is temporarily disabled.
+ /// TODO: Determine/discuss a suitable maximum
+ ///
+ public int MaxRedisTimeoutsThreshold { get; set; } = 10;
+
+ ///
+ /// Length of the sliding window in seconds to track Redis timeout exceptions.
+ /// TODO: Determine/discuss a suitable sliding window
+ ///
+ public int SlidingWindowSeconds { get; set; } = 120;
+ }
+
}
diff --git a/src/Core/Utilities/CustomRedisProcessingStrategy.cs b/src/Core/Utilities/CustomRedisProcessingStrategy.cs
new file mode 100644
index 0000000000..e3b4acd68e
--- /dev/null
+++ b/src/Core/Utilities/CustomRedisProcessingStrategy.cs
@@ -0,0 +1,102 @@
+using AspNetCoreRateLimit;
+using AspNetCoreRateLimit.Redis;
+using Bit.Core.Settings;
+using Microsoft.Extensions.Caching.Memory;
+using Microsoft.Extensions.Logging;
+using StackExchange.Redis;
+
+namespace Bit.Core.Utilities;
+
+///
+/// A modified version of that gracefully
+/// handles a disrupted Redis connection. If the connection is down or the number of failed requests within
+/// a given time period exceed the configured threshold, then rate limiting is temporarily disabled.
+///
+///
+/// This is necessary to ensure the service does not become unresponsive due to Redis being out of service. As
+/// the default implementation would throw an exception and exit the request pipeline for all requests.
+///
+public class CustomRedisProcessingStrategy : RedisProcessingStrategy
+{
+ private readonly IConnectionMultiplexer _connectionMultiplexer;
+ private readonly ILogger _logger;
+ private readonly IMemoryCache _memoryCache;
+ private readonly GlobalSettings.DistributedIpRateLimitingSettings _distributedSettings;
+
+ private const string _redisTimeoutCacheKey = "IpRateLimitRedisTimeout";
+
+ public CustomRedisProcessingStrategy(
+ IConnectionMultiplexer connectionMultiplexer,
+ IRateLimitConfiguration config,
+ ILogger logger,
+ IMemoryCache memoryCache,
+ GlobalSettings globalSettings)
+ : base(connectionMultiplexer, config, logger)
+ {
+ _connectionMultiplexer = connectionMultiplexer;
+ _logger = logger;
+ _memoryCache = memoryCache;
+ _distributedSettings = globalSettings.DistributedIpRateLimiting;
+ }
+
+ public override async Task ProcessRequestAsync(ClientRequestIdentity requestIdentity,
+ RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions,
+ CancellationToken cancellationToken = default)
+ {
+ // If Redis is down entirely, skip rate limiting
+ if (!_connectionMultiplexer.IsConnected)
+ {
+ _logger.LogDebug("Redis connection is down, skipping IP rate limiting");
+ return SkipRateLimitResult();
+ }
+
+ // Check if any Redis timeouts have occured recently
+ if (_memoryCache.TryGetValue(_redisTimeoutCacheKey, out var timeoutCounter))
+ {
+ // We've exceeded threshold, backoff Redis and skip rate limiting for now
+ if (timeoutCounter.Count >= _distributedSettings.MaxRedisTimeoutsThreshold)
+ {
+ _logger.LogDebug(
+ "Redis timeout threshold has been exceeded, backing off and skipping IP rate limiting");
+ return SkipRateLimitResult();
+ }
+ }
+
+ try
+ {
+ return await base.ProcessRequestAsync(requestIdentity, rule, counterKeyBuilder, rateLimitOptions, cancellationToken);
+ }
+ catch (RedisTimeoutException)
+ {
+ // If this is the first timeout we've had, start a new counter and sliding window
+ timeoutCounter ??= new TimeoutCounter()
+ {
+ Count = 0,
+ ExpiresAt = DateTime.UtcNow.AddSeconds(_distributedSettings.SlidingWindowSeconds)
+ };
+ timeoutCounter.Count++;
+
+ _memoryCache.Set(_redisTimeoutCacheKey, timeoutCounter,
+ new MemoryCacheEntryOptions { AbsoluteExpiration = timeoutCounter.ExpiresAt });
+
+ // Just because Redis timed out does not mean we should kill the request
+ return SkipRateLimitResult();
+ }
+ }
+
+ ///
+ /// A RateLimitCounter result used when the rate limiting middleware should
+ /// fail open and allow the request to proceed without checking request limits.
+ ///
+ private static RateLimitCounter SkipRateLimitResult()
+ {
+ return new RateLimitCounter { Count = 0, Timestamp = DateTime.UtcNow };
+ }
+
+ internal class TimeoutCounter
+ {
+ public DateTime ExpiresAt { get; init; }
+
+ public int Count { get; set; }
+ }
+}
diff --git a/src/Identity/appsettings.json b/src/Identity/appsettings.json
index fb3469b5fe..609a5004aa 100644
--- a/src/Identity/appsettings.json
+++ b/src/Identity/appsettings.json
@@ -59,6 +59,11 @@
"accessKeyId": "SECRET",
"accessKeySecret": "SECRET",
"region": "SECRET"
+ },
+ "distributedIpRateLimiting": {
+ "enabled": true,
+ "maxRedisTimeoutsThreshold": 10,
+ "slidingWindowSeconds": 120
}
},
"IpRateLimitOptions": {
diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs
index b2efe511d0..7fab4b2eed 100644
--- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs
+++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs
@@ -2,7 +2,6 @@
using System.Security.Claims;
using System.Security.Cryptography.X509Certificates;
using AspNetCoreRateLimit;
-using AspNetCoreRateLimit.Redis;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.HostedServices;
@@ -609,13 +608,20 @@ public static class ServiceCollectionExtensions
services.AddHostedService();
services.AddSingleton();
- if (string.IsNullOrEmpty(globalSettings.Redis.ConnectionString))
+ if (!globalSettings.DistributedIpRateLimiting.Enabled || string.IsNullOrEmpty(globalSettings.Redis.ConnectionString))
{
services.AddInMemoryRateLimiting();
}
else
{
- services.AddRedisRateLimiting(); // Requires a registered IConnectionMultiplexer
+ // Use memory stores for Ip and Client Policy stores as we don't currently use them
+ // and they add unnecessary Redis network delays checking for policies that don't exist
+ services.AddSingleton();
+ services.AddSingleton();
+
+ // Use a custom Redis processing strategy that skips Ip limiting if Redis is down
+ // Requires a registered IConnectionMultiplexer
+ services.AddSingleton();
}
}
diff --git a/test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs b/test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs
new file mode 100644
index 0000000000..e5b9bd5549
--- /dev/null
+++ b/test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs
@@ -0,0 +1,198 @@
+using AspNetCoreRateLimit;
+using Bit.Core.Settings;
+using Bit.Core.Utilities;
+using Microsoft.Extensions.Caching.Memory;
+using Microsoft.Extensions.Logging;
+using Moq;
+using StackExchange.Redis;
+using Xunit;
+
+namespace Bit.Core.Test.Utilities;
+
+public class CustomRedisProcessingStrategyTests
+{
+ #region Sample RateLimit Options for Testing
+
+ private readonly GlobalSettings _sampleSettings = new()
+ {
+ DistributedIpRateLimiting = new GlobalSettings.DistributedIpRateLimitingSettings
+ {
+ Enabled = true,
+ MaxRedisTimeoutsThreshold = 2,
+ SlidingWindowSeconds = 5
+ }
+ };
+
+ private readonly ClientRequestIdentity _sampleClientId = new()
+ {
+ ClientId = "test",
+ ClientIp = "127.0.0.1",
+ HttpVerb = "GET",
+ Path = "/"
+ };
+
+ private readonly RateLimitRule _sampleRule = new() { Endpoint = "/", Limit = 5, Period = "1m", PeriodTimespan = TimeSpan.FromMinutes(1) };
+
+ private readonly RateLimitOptions _sampleOptions = new() { };
+
+ #endregion
+
+ private readonly Mock _mockCounterKeyBuilder = new();
+ private Mock _mockDb;
+
+ public CustomRedisProcessingStrategyTests()
+ {
+ _mockCounterKeyBuilder
+ .Setup(x =>
+ x.Build(It.IsAny(), It.IsAny()))
+ .Returns(_sampleClientId.ClientId);
+ }
+
+ [Fact]
+ public async Task IncrementRateLimitCount_When_RedisIsHealthy()
+ {
+ // Arrange
+ var strategy = BuildProcessingStrategy();
+
+ // Act
+ var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
+ CancellationToken.None);
+
+ // Assert
+ Assert.Equal(1, result.Count);
+ VerifyRedisCalls(Times.Once());
+ }
+
+ [Fact]
+ public async Task SkipRateLimit_When_RedisIsDown()
+ {
+ // Arrange
+ var strategy = BuildProcessingStrategy(false);
+
+ // Act
+ var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
+ CancellationToken.None);
+
+ // Assert
+ Assert.Equal(0, result.Count);
+ VerifyRedisCalls(Times.Never());
+ }
+
+ [Fact]
+ public async Task SkipRateLimit_When_TimeoutThresholdExceeded()
+ {
+ // Arrange
+ var mockCache = new Mock();
+ object existingCount = new CustomRedisProcessingStrategy.TimeoutCounter
+ {
+ Count = _sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold + 1
+ };
+ mockCache.Setup(x => x.TryGetValue(It.IsAny