diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj
index 3177d15d35..f07a97e2ec 100644
--- a/src/Core/Core.csproj
+++ b/src/Core/Core.csproj
@@ -39,6 +39,7 @@
+
diff --git a/src/Infrastructure.EntityFramework/Configurations/CacheEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/Configurations/CacheEntityTypeConfiguration.cs
new file mode 100644
index 0000000000..7d7d88a6cd
--- /dev/null
+++ b/src/Infrastructure.EntityFramework/Configurations/CacheEntityTypeConfiguration.cs
@@ -0,0 +1,25 @@
+using Bit.Infrastructure.EntityFramework.Models;
+using Microsoft.EntityFrameworkCore;
+using Microsoft.EntityFrameworkCore.Metadata.Builders;
+
+namespace Bit.Infrastructure.EntityFramework.Configurations;
+
+public class CacheEntityTypeConfiguration : IEntityTypeConfiguration
+{
+ public void Configure(EntityTypeBuilder builder)
+ {
+ builder
+ .HasKey(s => s.Id)
+ .IsClustered();
+
+ builder
+ .Property(s => s.Id)
+ .ValueGeneratedNever();
+
+ builder
+ .HasIndex(s => s.ExpiresAtTime)
+ .IsClustered(false);
+
+ builder.ToTable(nameof(Cache));
+ }
+}
diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkCache.cs b/src/Infrastructure.EntityFramework/EntityFrameworkCache.cs
new file mode 100644
index 0000000000..1bffa1c77c
--- /dev/null
+++ b/src/Infrastructure.EntityFramework/EntityFrameworkCache.cs
@@ -0,0 +1,315 @@
+using Bit.Infrastructure.EntityFramework.Models;
+using Bit.Infrastructure.EntityFramework.Repositories;
+using Microsoft.EntityFrameworkCore;
+using Microsoft.Extensions.Caching.Distributed;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Bit.Infrastructure.EntityFramework;
+
+public class EntityFrameworkCache : IDistributedCache
+{
+#if DEBUG
+ // Used for debugging in tests
+ public Task scanTask;
+#endif
+ private static readonly TimeSpan _defaultSlidingExpiration = TimeSpan.FromMinutes(20);
+ private static readonly TimeSpan _expiredItemsDeletionInterval = TimeSpan.FromMinutes(30);
+ private DateTimeOffset _lastExpirationScan;
+ private readonly Action _deleteExpiredCachedItemsDelegate;
+ private readonly object _mutex = new();
+ private readonly IServiceScopeFactory _serviceScopeFactory;
+ private readonly TimeProvider _timeProvider;
+
+ public EntityFrameworkCache(
+ IServiceScopeFactory serviceScopeFactory,
+ TimeProvider timeProvider = null)
+ {
+ _deleteExpiredCachedItemsDelegate = DeleteExpiredCacheItems;
+ _serviceScopeFactory = serviceScopeFactory;
+ _timeProvider = timeProvider ?? TimeProvider.System;
+ }
+
+ public byte[] Get(string key)
+ {
+ ArgumentNullException.ThrowIfNull(key);
+
+ using var scope = _serviceScopeFactory.CreateScope();
+ var dbContext = GetDatabaseContext(scope);
+ var cache = dbContext.Cache
+ .Where(c => c.Id == key && _timeProvider.GetUtcNow().DateTime <= c.ExpiresAtTime)
+ .SingleOrDefault();
+
+ if (cache == null)
+ {
+ return null;
+ }
+
+ if (UpdateCacheExpiration(cache))
+ {
+ dbContext.SaveChanges();
+ }
+
+ ScanForExpiredItemsIfRequired();
+ return cache?.Value;
+ }
+
+ public async Task GetAsync(string key, CancellationToken token = default)
+ {
+ ArgumentNullException.ThrowIfNull(key);
+ token.ThrowIfCancellationRequested();
+
+ using var scope = _serviceScopeFactory.CreateScope();
+ var dbContext = GetDatabaseContext(scope);
+ var cache = await dbContext.Cache
+ .Where(c => c.Id == key && _timeProvider.GetUtcNow().DateTime <= c.ExpiresAtTime)
+ .SingleOrDefaultAsync(cancellationToken: token);
+
+ if (cache == null)
+ {
+ return null;
+ }
+
+ if (UpdateCacheExpiration(cache))
+ {
+ await dbContext.SaveChangesAsync(token);
+ }
+
+ ScanForExpiredItemsIfRequired();
+ return cache?.Value;
+ }
+
+ public void Refresh(string key) => Get(key);
+
+ public Task RefreshAsync(string key, CancellationToken token = default) => GetAsync(key, token);
+
+ public void Remove(string key)
+ {
+ ArgumentNullException.ThrowIfNull(key);
+
+ using var scope = _serviceScopeFactory.CreateScope();
+ GetDatabaseContext(scope).Cache
+ .Where(c => c.Id == key)
+ .ExecuteDelete();
+
+ ScanForExpiredItemsIfRequired();
+ }
+
+ public async Task RemoveAsync(string key, CancellationToken token = default)
+ {
+ ArgumentNullException.ThrowIfNull(key);
+
+ token.ThrowIfCancellationRequested();
+ using var scope = _serviceScopeFactory.CreateScope();
+ await GetDatabaseContext(scope).Cache
+ .Where(c => c.Id == key)
+ .ExecuteDeleteAsync(cancellationToken: token);
+
+ ScanForExpiredItemsIfRequired();
+ }
+
+ public void Set(string key, byte[] value, DistributedCacheEntryOptions options)
+ {
+ ArgumentNullException.ThrowIfNull(key);
+ ArgumentNullException.ThrowIfNull(value);
+ ArgumentNullException.ThrowIfNull(options);
+
+ using var scope = _serviceScopeFactory.CreateScope();
+ var dbContext = GetDatabaseContext(scope);
+ var cache = dbContext.Cache.Find(key);
+ var insert = cache == null;
+ cache = SetCache(cache, key, value, options);
+ if (insert)
+ {
+ dbContext.Add(cache);
+ }
+
+ try
+ {
+ dbContext.SaveChanges();
+ }
+ catch (DbUpdateException e)
+ {
+ if (IsDuplicateKeyException(e))
+ {
+ // There is a possibility that multiple requests can try to add the same item to the cache, in
+ // which case we receive a 'duplicate key' exception on the primary key column.
+ }
+ else
+ {
+ throw;
+ }
+ }
+
+ ScanForExpiredItemsIfRequired();
+ }
+
+ public async Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default)
+ {
+ ArgumentNullException.ThrowIfNull(key);
+ ArgumentNullException.ThrowIfNull(value);
+ ArgumentNullException.ThrowIfNull(options);
+
+ token.ThrowIfCancellationRequested();
+
+ using var scope = _serviceScopeFactory.CreateScope();
+ var dbContext = GetDatabaseContext(scope);
+ var cache = await dbContext.Cache.FindAsync(new object[] { key }, cancellationToken: token);
+ var insert = cache == null;
+ cache = SetCache(cache, key, value, options);
+ if (insert)
+ {
+ await dbContext.AddAsync(cache, token);
+ }
+
+ try
+ {
+ await dbContext.SaveChangesAsync(token);
+ }
+ catch (DbUpdateException e)
+ {
+ if (IsDuplicateKeyException(e))
+ {
+ // There is a possibility that multiple requests can try to add the same item to the cache, in
+ // which case we receive a 'duplicate key' exception on the primary key column.
+ }
+ else
+ {
+ throw;
+ }
+ }
+
+ ScanForExpiredItemsIfRequired();
+ }
+
+ private Cache SetCache(Cache cache, string key, byte[] value, DistributedCacheEntryOptions options)
+ {
+ var utcNow = _timeProvider.GetUtcNow().DateTime;
+
+ // resolve options
+ if (!options.AbsoluteExpiration.HasValue &&
+ !options.AbsoluteExpirationRelativeToNow.HasValue &&
+ !options.SlidingExpiration.HasValue)
+ {
+ options = new DistributedCacheEntryOptions
+ {
+ SlidingExpiration = _defaultSlidingExpiration
+ };
+ }
+
+ if (cache == null)
+ {
+ // do an insert
+ cache = new Cache { Id = key };
+ }
+
+ var slidingExpiration = (long?)options.SlidingExpiration?.TotalSeconds;
+
+ // calculate absolute expiration
+ DateTime? absoluteExpiration = null;
+ if (options.AbsoluteExpirationRelativeToNow.HasValue)
+ {
+ absoluteExpiration = utcNow.Add(options.AbsoluteExpirationRelativeToNow.Value);
+ }
+ else if (options.AbsoluteExpiration.HasValue)
+ {
+ if (options.AbsoluteExpiration.Value <= utcNow)
+ {
+ throw new InvalidOperationException("The absolute expiration value must be in the future.");
+ }
+
+ absoluteExpiration = options.AbsoluteExpiration.Value.DateTime;
+ }
+
+ // set values on cache
+ cache.Value = value;
+ cache.SlidingExpirationInSeconds = slidingExpiration;
+ cache.AbsoluteExpiration = absoluteExpiration;
+ if (slidingExpiration.HasValue)
+ {
+ cache.ExpiresAtTime = utcNow.AddSeconds(slidingExpiration.Value);
+ }
+ else if (absoluteExpiration.HasValue)
+ {
+ cache.ExpiresAtTime = absoluteExpiration.Value;
+ }
+ else
+ {
+ throw new InvalidOperationException("Either absolute or sliding expiration needs to be provided.");
+ }
+
+ return cache;
+ }
+
+ private bool UpdateCacheExpiration(Cache cache)
+ {
+ var utcNow = _timeProvider.GetUtcNow().DateTime;
+ if (cache.SlidingExpirationInSeconds.HasValue && (cache.AbsoluteExpiration.HasValue || cache.AbsoluteExpiration != cache.ExpiresAtTime))
+ {
+ if (cache.AbsoluteExpiration.HasValue && (cache.AbsoluteExpiration.Value - utcNow).TotalSeconds <= cache.SlidingExpirationInSeconds)
+ {
+ cache.ExpiresAtTime = cache.AbsoluteExpiration.Value;
+ }
+ else
+ {
+ cache.ExpiresAtTime = utcNow.AddSeconds(cache.SlidingExpirationInSeconds.Value);
+ }
+ return true;
+ }
+ return false;
+ }
+
+ private void ScanForExpiredItemsIfRequired()
+ {
+ lock (_mutex)
+ {
+ var utcNow = _timeProvider.GetUtcNow().DateTime;
+ if ((utcNow - _lastExpirationScan) > _expiredItemsDeletionInterval)
+ {
+ _lastExpirationScan = utcNow;
+#if DEBUG
+ scanTask =
+#endif
+ Task.Run(_deleteExpiredCachedItemsDelegate);
+ }
+ }
+ }
+
+ private void DeleteExpiredCacheItems()
+ {
+ using var scope = _serviceScopeFactory.CreateScope();
+ GetDatabaseContext(scope).Cache
+ .Where(c => _timeProvider.GetUtcNow().DateTime > c.ExpiresAtTime)
+ .ExecuteDelete();
+ }
+
+ private DatabaseContext GetDatabaseContext(IServiceScope serviceScope)
+ {
+ return serviceScope.ServiceProvider.GetRequiredService();
+ }
+
+ private static bool IsDuplicateKeyException(DbUpdateException e)
+ {
+ // MySQL
+ if (e.InnerException is MySqlConnector.MySqlException myEx)
+ {
+ return myEx.ErrorCode == MySqlConnector.MySqlErrorCode.DuplicateKeyEntry;
+ }
+ // SQL Server
+ else if (e.InnerException is Microsoft.Data.SqlClient.SqlException msEx)
+ {
+ return msEx.Errors != null &&
+ msEx.Errors.Cast().Any(error => error.Number == 2627);
+ }
+ // Postgres
+ else if (e.InnerException is Npgsql.PostgresException pgEx)
+ {
+ return pgEx.SqlState == "23505";
+ }
+ // Sqlite
+ else if (e.InnerException is Microsoft.Data.Sqlite.SqliteException liteEx)
+ {
+ return liteEx.SqliteErrorCode == 19 && liteEx.SqliteExtendedErrorCode == 1555;
+ }
+ return false;
+ }
+}
diff --git a/src/Infrastructure.EntityFramework/Models/Cache.cs b/src/Infrastructure.EntityFramework/Models/Cache.cs
new file mode 100644
index 0000000000..f03c09d8dc
--- /dev/null
+++ b/src/Infrastructure.EntityFramework/Models/Cache.cs
@@ -0,0 +1,13 @@
+using System.ComponentModel.DataAnnotations;
+
+namespace Bit.Infrastructure.EntityFramework.Models;
+
+public class Cache
+{
+ [StringLength(449)]
+ public string Id { get; set; }
+ public byte[] Value { get; set; }
+ public DateTime ExpiresAtTime { get; set; }
+ public long? SlidingExpirationInSeconds { get; set; }
+ public DateTime? AbsoluteExpiration { get; set; }
+}
diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs
index 8712e0c17d..f1d514d22e 100644
--- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs
+++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs
@@ -32,6 +32,7 @@ public class DatabaseContext : DbContext
public DbSet GroupSecretAccessPolicy { get; set; }
public DbSet ServiceAccountSecretAccessPolicy { get; set; }
public DbSet ApiKeys { get; set; }
+ public DbSet Cache { get; set; }
public DbSet Ciphers { get; set; }
public DbSet Collections { get; set; }
public DbSet CollectionCiphers { get; set; }
diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs
index f381305745..3f5b464b58 100644
--- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs
+++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs
@@ -69,41 +69,7 @@ public static class ServiceCollectionExtensions
{
public static SupportedDatabaseProviders AddDatabaseRepositories(this IServiceCollection services, GlobalSettings globalSettings)
{
- var selectedDatabaseProvider = globalSettings.DatabaseProvider;
- var provider = SupportedDatabaseProviders.SqlServer;
- var connectionString = string.Empty;
-
- if (!string.IsNullOrWhiteSpace(selectedDatabaseProvider))
- {
- switch (selectedDatabaseProvider.ToLowerInvariant())
- {
- case "postgres":
- case "postgresql":
- provider = SupportedDatabaseProviders.Postgres;
- connectionString = globalSettings.PostgreSql.ConnectionString;
- break;
- case "mysql":
- case "mariadb":
- provider = SupportedDatabaseProviders.MySql;
- connectionString = globalSettings.MySql.ConnectionString;
- break;
- case "sqlite":
- provider = SupportedDatabaseProviders.Sqlite;
- connectionString = globalSettings.Sqlite.ConnectionString;
- break;
- case "sqlserver":
- connectionString = globalSettings.SqlServer.ConnectionString;
- break;
- default:
- break;
- }
- }
- else
- {
- // Default to attempting to use SqlServer connection string if globalSettings.DatabaseProvider has no value.
- connectionString = globalSettings.SqlServer.ConnectionString;
- }
-
+ var (provider, connectionString) = GetDatabaseProvider(globalSettings);
services.SetupEntityFramework(connectionString, provider);
if (provider != SupportedDatabaseProviders.SqlServer)
@@ -730,7 +696,20 @@ public static class ServiceCollectionExtensions
}
else
{
- services.AddDistributedMemoryCache();
+ var (databaseProvider, databaseConnectionString) = GetDatabaseProvider(globalSettings);
+ if (databaseProvider == SupportedDatabaseProviders.SqlServer)
+ {
+ services.AddDistributedSqlServerCache(o =>
+ {
+ o.ConnectionString = databaseConnectionString;
+ o.SchemaName = "dbo";
+ o.TableName = "Cache";
+ });
+ }
+ else
+ {
+ services.AddSingleton();
+ }
}
if (!string.IsNullOrEmpty(globalSettings.DistributedCache?.Cosmos?.ConnectionString))
@@ -746,7 +725,7 @@ public static class ServiceCollectionExtensions
}
else
{
- services.AddKeyedSingleton("persistent");
+ services.AddKeyedSingleton("persistent", (s, _) => s.GetRequiredService());
}
}
@@ -762,4 +741,45 @@ public static class ServiceCollectionExtensions
return services;
}
+
+ private static (SupportedDatabaseProviders provider, string connectionString)
+ GetDatabaseProvider(GlobalSettings globalSettings)
+ {
+ var selectedDatabaseProvider = globalSettings.DatabaseProvider;
+ var provider = SupportedDatabaseProviders.SqlServer;
+ var connectionString = string.Empty;
+
+ if (!string.IsNullOrWhiteSpace(selectedDatabaseProvider))
+ {
+ switch (selectedDatabaseProvider.ToLowerInvariant())
+ {
+ case "postgres":
+ case "postgresql":
+ provider = SupportedDatabaseProviders.Postgres;
+ connectionString = globalSettings.PostgreSql.ConnectionString;
+ break;
+ case "mysql":
+ case "mariadb":
+ provider = SupportedDatabaseProviders.MySql;
+ connectionString = globalSettings.MySql.ConnectionString;
+ break;
+ case "sqlite":
+ provider = SupportedDatabaseProviders.Sqlite;
+ connectionString = globalSettings.Sqlite.ConnectionString;
+ break;
+ case "sqlserver":
+ connectionString = globalSettings.SqlServer.ConnectionString;
+ break;
+ default:
+ break;
+ }
+ }
+ else
+ {
+ // Default to attempting to use SqlServer connection string if globalSettings.DatabaseProvider has no value.
+ connectionString = globalSettings.SqlServer.ConnectionString;
+ }
+
+ return (provider, connectionString);
+ }
}
diff --git a/src/Sql/dbo/Tables/Cache.sql b/src/Sql/dbo/Tables/Cache.sql
new file mode 100644
index 0000000000..b66d0dd61a
--- /dev/null
+++ b/src/Sql/dbo/Tables/Cache.sql
@@ -0,0 +1,14 @@
+CREATE TABLE [dbo].[Cache]
+(
+ [Id] NVARCHAR (449) NOT NULL,
+ [Value] VARBINARY (MAX) NOT NULL,
+ [ExpiresAtTime] DATETIMEOFFSET (7) NOT NULL,
+ [SlidingExpirationInSeconds] BIGINT NULL,
+ [AbsoluteExpiration] DATETIMEOFFSET (7) NULL,
+ CONSTRAINT [PK_Cache] PRIMARY KEY CLUSTERED ([Id] ASC)
+);
+GO
+
+CREATE NONCLUSTERED INDEX [IX_Cache_ExpiresAtTime]
+ ON [dbo].[Cache]([ExpiresAtTime] ASC);
+GO
diff --git a/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs b/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs
index 2c12890ca1..2e55426e78 100644
--- a/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs
+++ b/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs
@@ -3,9 +3,11 @@ using Bit.Core.Enums;
using Bit.Core.Settings;
using Bit.Infrastructure.Dapper;
using Bit.Infrastructure.EntityFramework;
+using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Time.Testing;
using Xunit.Sdk;
namespace Bit.Infrastructure.IntegrationTest;
@@ -13,6 +15,7 @@ namespace Bit.Infrastructure.IntegrationTest;
public class DatabaseDataAttribute : DataAttribute
{
public bool SelfHosted { get; set; }
+ public bool UseFakeTimeProvider { get; set; }
public override IEnumerable