mirror of
https://github.com/bitwarden/server.git
synced 2025-07-01 16:12:49 -05:00
[EC-502] Rate Limiting Improvements (#2231)
* [EC-502] Add custom Redis IP rate limit processing strategy * [EC-502] Formatting * [EC-502] Add documentation and app setting config options * [EC-502] Formatting * [EC-502] Fix appsettings.json keys * [EC-502] Replace magic string for cache key * [EC-502] Add tests for custom processing strategy * [EC-502] Formatting * [EC-502] Use base class for custom processing strategy * [EC-502] Fix failing test
This commit is contained in:
@ -69,6 +69,11 @@
|
|||||||
"accessKeyId": "SECRET",
|
"accessKeyId": "SECRET",
|
||||||
"accessKeySecret": "SECRET",
|
"accessKeySecret": "SECRET",
|
||||||
"region": "SECRET"
|
"region": "SECRET"
|
||||||
|
},
|
||||||
|
"distributedIpRateLimiting": {
|
||||||
|
"enabled": true,
|
||||||
|
"maxRedisTimeoutsThreshold": 10,
|
||||||
|
"slidingWindowSeconds": 120
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"IpRateLimitOptions": {
|
"IpRateLimitOptions": {
|
||||||
|
@ -69,6 +69,8 @@ public class GlobalSettings : IGlobalSettings
|
|||||||
public virtual ISsoSettings Sso { get; set; } = new SsoSettings();
|
public virtual ISsoSettings Sso { get; set; } = new SsoSettings();
|
||||||
public virtual StripeSettings Stripe { get; set; } = new StripeSettings();
|
public virtual StripeSettings Stripe { get; set; } = new StripeSettings();
|
||||||
public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new TwoFactorAuthSettings();
|
public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new TwoFactorAuthSettings();
|
||||||
|
public virtual DistributedIpRateLimitingSettings DistributedIpRateLimiting { get; set; } =
|
||||||
|
new DistributedIpRateLimitingSettings();
|
||||||
|
|
||||||
public string BuildExternalUri(string explicitValue, string name)
|
public string BuildExternalUri(string explicitValue, string name)
|
||||||
{
|
{
|
||||||
@ -498,4 +500,23 @@ public class GlobalSettings : IGlobalSettings
|
|||||||
{
|
{
|
||||||
public bool EmailOnNewDeviceLogin { get; set; } = false;
|
public bool EmailOnNewDeviceLogin { get; set; } = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class DistributedIpRateLimitingSettings
|
||||||
|
{
|
||||||
|
public bool Enabled { get; set; } = true;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Maximum number of Redis timeouts that can be experienced within the sliding timeout
|
||||||
|
/// window before IP rate limiting is temporarily disabled.
|
||||||
|
/// TODO: Determine/discuss a suitable maximum
|
||||||
|
/// </summary>
|
||||||
|
public int MaxRedisTimeoutsThreshold { get; set; } = 10;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Length of the sliding window in seconds to track Redis timeout exceptions.
|
||||||
|
/// TODO: Determine/discuss a suitable sliding window
|
||||||
|
/// </summary>
|
||||||
|
public int SlidingWindowSeconds { get; set; } = 120;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
102
src/Core/Utilities/CustomRedisProcessingStrategy.cs
Normal file
102
src/Core/Utilities/CustomRedisProcessingStrategy.cs
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
using AspNetCoreRateLimit;
|
||||||
|
using AspNetCoreRateLimit.Redis;
|
||||||
|
using Bit.Core.Settings;
|
||||||
|
using Microsoft.Extensions.Caching.Memory;
|
||||||
|
using Microsoft.Extensions.Logging;
|
||||||
|
using StackExchange.Redis;
|
||||||
|
|
||||||
|
namespace Bit.Core.Utilities;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// A modified version of <see cref="AspNetCoreRateLimit.Redis.RedisProcessingStrategy"/> that gracefully
|
||||||
|
/// handles a disrupted Redis connection. If the connection is down or the number of failed requests within
|
||||||
|
/// a given time period exceed the configured threshold, then rate limiting is temporarily disabled.
|
||||||
|
/// </summary>
|
||||||
|
/// <remarks>
|
||||||
|
/// This is necessary to ensure the service does not become unresponsive due to Redis being out of service. As
|
||||||
|
/// the default implementation would throw an exception and exit the request pipeline for all requests.
|
||||||
|
/// </remarks>
|
||||||
|
public class CustomRedisProcessingStrategy : RedisProcessingStrategy
|
||||||
|
{
|
||||||
|
private readonly IConnectionMultiplexer _connectionMultiplexer;
|
||||||
|
private readonly ILogger<CustomRedisProcessingStrategy> _logger;
|
||||||
|
private readonly IMemoryCache _memoryCache;
|
||||||
|
private readonly GlobalSettings.DistributedIpRateLimitingSettings _distributedSettings;
|
||||||
|
|
||||||
|
private const string _redisTimeoutCacheKey = "IpRateLimitRedisTimeout";
|
||||||
|
|
||||||
|
public CustomRedisProcessingStrategy(
|
||||||
|
IConnectionMultiplexer connectionMultiplexer,
|
||||||
|
IRateLimitConfiguration config,
|
||||||
|
ILogger<CustomRedisProcessingStrategy> logger,
|
||||||
|
IMemoryCache memoryCache,
|
||||||
|
GlobalSettings globalSettings)
|
||||||
|
: base(connectionMultiplexer, config, logger)
|
||||||
|
{
|
||||||
|
_connectionMultiplexer = connectionMultiplexer;
|
||||||
|
_logger = logger;
|
||||||
|
_memoryCache = memoryCache;
|
||||||
|
_distributedSettings = globalSettings.DistributedIpRateLimiting;
|
||||||
|
}
|
||||||
|
|
||||||
|
public override async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity,
|
||||||
|
RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions,
|
||||||
|
CancellationToken cancellationToken = default)
|
||||||
|
{
|
||||||
|
// If Redis is down entirely, skip rate limiting
|
||||||
|
if (!_connectionMultiplexer.IsConnected)
|
||||||
|
{
|
||||||
|
_logger.LogDebug("Redis connection is down, skipping IP rate limiting");
|
||||||
|
return SkipRateLimitResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if any Redis timeouts have occured recently
|
||||||
|
if (_memoryCache.TryGetValue<TimeoutCounter>(_redisTimeoutCacheKey, out var timeoutCounter))
|
||||||
|
{
|
||||||
|
// We've exceeded threshold, backoff Redis and skip rate limiting for now
|
||||||
|
if (timeoutCounter.Count >= _distributedSettings.MaxRedisTimeoutsThreshold)
|
||||||
|
{
|
||||||
|
_logger.LogDebug(
|
||||||
|
"Redis timeout threshold has been exceeded, backing off and skipping IP rate limiting");
|
||||||
|
return SkipRateLimitResult();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
return await base.ProcessRequestAsync(requestIdentity, rule, counterKeyBuilder, rateLimitOptions, cancellationToken);
|
||||||
|
}
|
||||||
|
catch (RedisTimeoutException)
|
||||||
|
{
|
||||||
|
// If this is the first timeout we've had, start a new counter and sliding window
|
||||||
|
timeoutCounter ??= new TimeoutCounter()
|
||||||
|
{
|
||||||
|
Count = 0,
|
||||||
|
ExpiresAt = DateTime.UtcNow.AddSeconds(_distributedSettings.SlidingWindowSeconds)
|
||||||
|
};
|
||||||
|
timeoutCounter.Count++;
|
||||||
|
|
||||||
|
_memoryCache.Set(_redisTimeoutCacheKey, timeoutCounter,
|
||||||
|
new MemoryCacheEntryOptions { AbsoluteExpiration = timeoutCounter.ExpiresAt });
|
||||||
|
|
||||||
|
// Just because Redis timed out does not mean we should kill the request
|
||||||
|
return SkipRateLimitResult();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// A RateLimitCounter result used when the rate limiting middleware should
|
||||||
|
/// fail open and allow the request to proceed without checking request limits.
|
||||||
|
/// </summary>
|
||||||
|
private static RateLimitCounter SkipRateLimitResult()
|
||||||
|
{
|
||||||
|
return new RateLimitCounter { Count = 0, Timestamp = DateTime.UtcNow };
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class TimeoutCounter
|
||||||
|
{
|
||||||
|
public DateTime ExpiresAt { get; init; }
|
||||||
|
|
||||||
|
public int Count { get; set; }
|
||||||
|
}
|
||||||
|
}
|
@ -59,6 +59,11 @@
|
|||||||
"accessKeyId": "SECRET",
|
"accessKeyId": "SECRET",
|
||||||
"accessKeySecret": "SECRET",
|
"accessKeySecret": "SECRET",
|
||||||
"region": "SECRET"
|
"region": "SECRET"
|
||||||
|
},
|
||||||
|
"distributedIpRateLimiting": {
|
||||||
|
"enabled": true,
|
||||||
|
"maxRedisTimeoutsThreshold": 10,
|
||||||
|
"slidingWindowSeconds": 120
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"IpRateLimitOptions": {
|
"IpRateLimitOptions": {
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
using System.Security.Claims;
|
using System.Security.Claims;
|
||||||
using System.Security.Cryptography.X509Certificates;
|
using System.Security.Cryptography.X509Certificates;
|
||||||
using AspNetCoreRateLimit;
|
using AspNetCoreRateLimit;
|
||||||
using AspNetCoreRateLimit.Redis;
|
|
||||||
using Bit.Core.Entities;
|
using Bit.Core.Entities;
|
||||||
using Bit.Core.Enums;
|
using Bit.Core.Enums;
|
||||||
using Bit.Core.HostedServices;
|
using Bit.Core.HostedServices;
|
||||||
@ -609,13 +608,20 @@ public static class ServiceCollectionExtensions
|
|||||||
services.AddHostedService<IpRateLimitSeedStartupService>();
|
services.AddHostedService<IpRateLimitSeedStartupService>();
|
||||||
services.AddSingleton<IRateLimitConfiguration, RateLimitConfiguration>();
|
services.AddSingleton<IRateLimitConfiguration, RateLimitConfiguration>();
|
||||||
|
|
||||||
if (string.IsNullOrEmpty(globalSettings.Redis.ConnectionString))
|
if (!globalSettings.DistributedIpRateLimiting.Enabled || string.IsNullOrEmpty(globalSettings.Redis.ConnectionString))
|
||||||
{
|
{
|
||||||
services.AddInMemoryRateLimiting();
|
services.AddInMemoryRateLimiting();
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
services.AddRedisRateLimiting(); // Requires a registered IConnectionMultiplexer
|
// Use memory stores for Ip and Client Policy stores as we don't currently use them
|
||||||
|
// and they add unnecessary Redis network delays checking for policies that don't exist
|
||||||
|
services.AddSingleton<IIpPolicyStore, MemoryCacheIpPolicyStore>();
|
||||||
|
services.AddSingleton<IClientPolicyStore, MemoryCacheClientPolicyStore>();
|
||||||
|
|
||||||
|
// Use a custom Redis processing strategy that skips Ip limiting if Redis is down
|
||||||
|
// Requires a registered IConnectionMultiplexer
|
||||||
|
services.AddSingleton<IProcessingStrategy, CustomRedisProcessingStrategy>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
198
test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs
Normal file
198
test/Core.Test/Utilities/CustomRedisProcessingStrategyTests.cs
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
using AspNetCoreRateLimit;
|
||||||
|
using Bit.Core.Settings;
|
||||||
|
using Bit.Core.Utilities;
|
||||||
|
using Microsoft.Extensions.Caching.Memory;
|
||||||
|
using Microsoft.Extensions.Logging;
|
||||||
|
using Moq;
|
||||||
|
using StackExchange.Redis;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Bit.Core.Test.Utilities;
|
||||||
|
|
||||||
|
public class CustomRedisProcessingStrategyTests
|
||||||
|
{
|
||||||
|
#region Sample RateLimit Options for Testing
|
||||||
|
|
||||||
|
private readonly GlobalSettings _sampleSettings = new()
|
||||||
|
{
|
||||||
|
DistributedIpRateLimiting = new GlobalSettings.DistributedIpRateLimitingSettings
|
||||||
|
{
|
||||||
|
Enabled = true,
|
||||||
|
MaxRedisTimeoutsThreshold = 2,
|
||||||
|
SlidingWindowSeconds = 5
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
private readonly ClientRequestIdentity _sampleClientId = new()
|
||||||
|
{
|
||||||
|
ClientId = "test",
|
||||||
|
ClientIp = "127.0.0.1",
|
||||||
|
HttpVerb = "GET",
|
||||||
|
Path = "/"
|
||||||
|
};
|
||||||
|
|
||||||
|
private readonly RateLimitRule _sampleRule = new() { Endpoint = "/", Limit = 5, Period = "1m", PeriodTimespan = TimeSpan.FromMinutes(1) };
|
||||||
|
|
||||||
|
private readonly RateLimitOptions _sampleOptions = new() { };
|
||||||
|
|
||||||
|
#endregion
|
||||||
|
|
||||||
|
private readonly Mock<ICounterKeyBuilder> _mockCounterKeyBuilder = new();
|
||||||
|
private Mock<IDatabase> _mockDb;
|
||||||
|
|
||||||
|
public CustomRedisProcessingStrategyTests()
|
||||||
|
{
|
||||||
|
_mockCounterKeyBuilder
|
||||||
|
.Setup(x =>
|
||||||
|
x.Build(It.IsAny<ClientRequestIdentity>(), It.IsAny<RateLimitRule>()))
|
||||||
|
.Returns(_sampleClientId.ClientId);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task IncrementRateLimitCount_When_RedisIsHealthy()
|
||||||
|
{
|
||||||
|
// Arrange
|
||||||
|
var strategy = BuildProcessingStrategy();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||||
|
CancellationToken.None);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
Assert.Equal(1, result.Count);
|
||||||
|
VerifyRedisCalls(Times.Once());
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task SkipRateLimit_When_RedisIsDown()
|
||||||
|
{
|
||||||
|
// Arrange
|
||||||
|
var strategy = BuildProcessingStrategy(false);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||||
|
CancellationToken.None);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
Assert.Equal(0, result.Count);
|
||||||
|
VerifyRedisCalls(Times.Never());
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task SkipRateLimit_When_TimeoutThresholdExceeded()
|
||||||
|
{
|
||||||
|
// Arrange
|
||||||
|
var mockCache = new Mock<IMemoryCache>();
|
||||||
|
object existingCount = new CustomRedisProcessingStrategy.TimeoutCounter
|
||||||
|
{
|
||||||
|
Count = _sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold + 1
|
||||||
|
};
|
||||||
|
mockCache.Setup(x => x.TryGetValue(It.IsAny<object>(), out existingCount)).Returns(true);
|
||||||
|
|
||||||
|
var strategy = BuildProcessingStrategy(mockCache: mockCache.Object);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||||
|
CancellationToken.None);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
Assert.Equal(0, result.Count);
|
||||||
|
VerifyRedisCalls(Times.Never());
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task SkipRateLimit_When_RedisTimeoutException()
|
||||||
|
{
|
||||||
|
// Arrange
|
||||||
|
var mockCache = new Mock<IMemoryCache>();
|
||||||
|
var mockCacheEntry = new Mock<ICacheEntry>();
|
||||||
|
mockCacheEntry.SetupAllProperties();
|
||||||
|
mockCache.Setup(x => x.CreateEntry(It.IsAny<object>())).Returns(mockCacheEntry.Object);
|
||||||
|
|
||||||
|
var strategy = BuildProcessingStrategy(mockCache: mockCache.Object, throwRedisTimeout: true);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
var result = await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||||
|
CancellationToken.None);
|
||||||
|
|
||||||
|
var timeoutCounter = ((CustomRedisProcessingStrategy.TimeoutCounter)mockCacheEntry.Object.Value);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
Assert.Equal(0, result.Count); // Skip rate limiting
|
||||||
|
VerifyRedisCalls(Times.Once());
|
||||||
|
|
||||||
|
Assert.Equal(1, timeoutCounter.Count); // Timeout count increased/cached
|
||||||
|
Assert.NotNull(mockCacheEntry.Object.AbsoluteExpiration);
|
||||||
|
mockCache.Verify(x => x.CreateEntry(It.IsAny<object>()));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task BackoffRedis_After_ThresholdExceeded()
|
||||||
|
{
|
||||||
|
// Arrange
|
||||||
|
var memoryCache = new MemoryCache(new MemoryCacheOptions());
|
||||||
|
var strategy = BuildProcessingStrategy(mockCache: memoryCache, throwRedisTimeout: true);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
|
||||||
|
// Redis Timeout 1
|
||||||
|
await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||||
|
CancellationToken.None);
|
||||||
|
|
||||||
|
// Redis Timeout 2
|
||||||
|
await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||||
|
CancellationToken.None);
|
||||||
|
|
||||||
|
// Skip Redis
|
||||||
|
await strategy.ProcessRequestAsync(_sampleClientId, _sampleRule, _mockCounterKeyBuilder.Object, _sampleOptions,
|
||||||
|
CancellationToken.None);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
VerifyRedisCalls(Times.Exactly(_sampleSettings.DistributedIpRateLimiting.MaxRedisTimeoutsThreshold));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void VerifyRedisCalls(Times times)
|
||||||
|
{
|
||||||
|
_mockDb.Verify(x =>
|
||||||
|
x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()),
|
||||||
|
times);
|
||||||
|
}
|
||||||
|
|
||||||
|
private CustomRedisProcessingStrategy BuildProcessingStrategy(
|
||||||
|
bool isRedisConnected = true,
|
||||||
|
bool throwRedisTimeout = false,
|
||||||
|
IMemoryCache mockCache = null)
|
||||||
|
{
|
||||||
|
var mockRedisConnection = new Mock<IConnectionMultiplexer>();
|
||||||
|
|
||||||
|
mockRedisConnection.Setup(x => x.IsConnected).Returns(isRedisConnected);
|
||||||
|
|
||||||
|
_mockDb = new Mock<IDatabase>();
|
||||||
|
|
||||||
|
var mockScriptEvaluate = _mockDb
|
||||||
|
.Setup(x =>
|
||||||
|
x.ScriptEvaluateAsync(It.IsAny<LuaScript>(), It.IsAny<object>(), It.IsAny<CommandFlags>()));
|
||||||
|
|
||||||
|
if (throwRedisTimeout)
|
||||||
|
{
|
||||||
|
mockScriptEvaluate.ThrowsAsync(new RedisTimeoutException("Timeout", CommandStatus.WaitingToBeSent));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
mockScriptEvaluate.ReturnsAsync(RedisResult.Create(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRedisConnection
|
||||||
|
.Setup(x =>
|
||||||
|
x.GetDatabase(It.IsAny<int>(), It.IsAny<object>()))
|
||||||
|
.Returns(_mockDb.Object);
|
||||||
|
|
||||||
|
var mockLogger = new Mock<ILogger<CustomRedisProcessingStrategy>>();
|
||||||
|
var mockConfig = new Mock<IRateLimitConfiguration>();
|
||||||
|
|
||||||
|
mockCache ??= new Mock<IMemoryCache>().Object;
|
||||||
|
|
||||||
|
return new CustomRedisProcessingStrategy(mockRedisConnection.Object, mockConfig.Object,
|
||||||
|
mockLogger.Object, mockCache, _sampleSettings);
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user