1
0
mirror of https://github.com/bitwarden/server.git synced 2025-05-20 19:14:32 -05:00

Feature/phishing detection cronjob (#5512)

* Added caching to EF implementation. Added error handling and logging

* Refactored update method to use sqlbulkcopy instead of performing a round trip for each new insert

* Initial implementation for quartz job to get list of phishing domains

* Updated phishing domain settings to be its own interface

* Add phishing domain detection with checksum-based updates
This commit is contained in:
Conner Turnbull 2025-03-18 08:07:05 -04:00 committed by GitHub
parent 6e0df19ae4
commit 370a69a86f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 660 additions and 80 deletions

View File

@ -14,4 +14,11 @@ public class PhishingDomainsController(IPhishingDomainRepository phishingDomainR
var domains = await phishingDomainRepository.GetActivePhishingDomainsAsync();
return Ok(domains);
}
[HttpGet("checksum")]
public async Task<ActionResult<string>> GetChecksumAsync()
{
var checksum = await phishingDomainRepository.GetCurrentChecksumAsync();
return Ok(checksum);
}
}

View File

@ -58,6 +58,13 @@ public class JobsHostedService : BaseJobsHostedService
.StartNow()
.WithCronSchedule("0 0 * * * ?")
.Build();
var updatePhishingDomainsTrigger = TriggerBuilder.Create()
.WithIdentity("UpdatePhishingDomainsTrigger")
.StartNow()
.WithSimpleSchedule(x => x
.WithIntervalInHours(24)
.RepeatForever())
.Build();
var jobs = new List<Tuple<Type, ITrigger>>
@ -68,6 +75,7 @@ public class JobsHostedService : BaseJobsHostedService
new Tuple<Type, ITrigger>(typeof(ValidateUsersJob), everyTopOfTheSixthHourTrigger),
new Tuple<Type, ITrigger>(typeof(ValidateOrganizationsJob), everyTwelfthHourAndThirtyMinutesTrigger),
new Tuple<Type, ITrigger>(typeof(ValidateOrganizationDomainJob), validateOrganizationDomainTrigger),
new Tuple<Type, ITrigger>(typeof(UpdatePhishingDomainsJob), updatePhishingDomainsTrigger),
};
if (_globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication)
@ -96,6 +104,7 @@ public class JobsHostedService : BaseJobsHostedService
services.AddTransient<ValidateUsersJob>();
services.AddTransient<ValidateOrganizationsJob>();
services.AddTransient<ValidateOrganizationDomainJob>();
services.AddTransient<UpdatePhishingDomainsJob>();
}
public static void AddCommercialSecretsManagerJobServices(IServiceCollection services)

View File

@ -0,0 +1,87 @@
using Bit.Core;
using Bit.Core.Jobs;
using Bit.Core.PhishingDomainFeatures.Interfaces;
using Bit.Core.Repositories;
using Bit.Core.Settings;
using Quartz;
namespace Bit.Api.Jobs;
public class UpdatePhishingDomainsJob : BaseJob
{
private readonly GlobalSettings _globalSettings;
private readonly IPhishingDomainRepository _phishingDomainRepository;
private readonly ICloudPhishingDomainQuery _cloudPhishingDomainQuery;
public UpdatePhishingDomainsJob(
GlobalSettings globalSettings,
IPhishingDomainRepository phishingDomainRepository,
ICloudPhishingDomainQuery cloudPhishingDomainQuery,
ILogger<UpdatePhishingDomainsJob> logger)
: base(logger)
{
_globalSettings = globalSettings;
_phishingDomainRepository = phishingDomainRepository;
_cloudPhishingDomainQuery = cloudPhishingDomainQuery;
}
protected override async Task ExecuteJobAsync(IJobExecutionContext context)
{
if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.UpdateUrl))
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. No URL configured.");
return;
}
if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. Cloud communication is disabled in global settings.");
return;
}
// Get the remote checksum
var remoteChecksum = await _cloudPhishingDomainQuery.GetRemoteChecksumAsync();
if (string.IsNullOrWhiteSpace(remoteChecksum))
{
_logger.LogWarning(Constants.BypassFiltersEventId, "Could not retrieve remote checksum. Skipping update.");
return;
}
// Get the current checksum from the database
var currentChecksum = await _phishingDomainRepository.GetCurrentChecksumAsync();
// Compare checksums to determine if update is needed
if (string.Equals(currentChecksum, remoteChecksum, StringComparison.OrdinalIgnoreCase))
{
_logger.LogInformation(Constants.BypassFiltersEventId,
"Phishing domains list is up to date (checksum: {Checksum}). Skipping update.",
currentChecksum);
return;
}
_logger.LogInformation(Constants.BypassFiltersEventId,
"Checksums differ (current: {CurrentChecksum}, remote: {RemoteChecksum}). Fetching updated domains from {Source}.",
currentChecksum, remoteChecksum, _globalSettings.SelfHosted ? "Bitwarden cloud API" : "external source");
try
{
var domains = await _cloudPhishingDomainQuery.GetPhishingDomainsAsync();
if (domains.Count > 0)
{
_logger.LogInformation(Constants.BypassFiltersEventId, "Updating {Count} phishing domains with checksum {Checksum}.",
domains.Count, remoteChecksum);
await _phishingDomainRepository.UpdatePhishingDomainsAsync(domains, remoteChecksum);
_logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated phishing domains.");
}
else
{
_logger.LogWarning(Constants.BypassFiltersEventId, "No valid domains found in the response. Skipping update.");
}
}
catch (Exception ex)
{
_logger.LogError(Constants.BypassFiltersEventId, ex, "Error updating phishing domains.");
}
}
}

View File

@ -177,6 +177,7 @@ public class Startup
services.AddBillingOperations();
services.AddReportingServices();
services.AddImportServices();
services.AddPhishingDomainServices(globalSettings);
// Authorization Handlers
services.AddAuthorizationHandlers();

View File

@ -2,6 +2,8 @@
using Bit.Api.Vault.AuthorizationHandlers.Collections;
using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Authorization;
using Bit.Core.IdentityServer;
using Bit.Core.PhishingDomainFeatures;
using Bit.Core.PhishingDomainFeatures.Interfaces;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Bit.Core.Vault.Authorization.SecurityTasks;
@ -106,4 +108,22 @@ public static class ServiceCollectionExtensions
services.AddScoped<IAuthorizationHandler, SecurityTaskAuthorizationHandler>();
services.AddScoped<IAuthorizationHandler, SecurityTaskOrganizationAuthorizationHandler>();
}
public static void AddPhishingDomainServices(this IServiceCollection services, GlobalSettings globalSettings)
{
services.AddHttpClient("PhishingDomains", client =>
{
client.DefaultRequestHeaders.Add("User-Agent", globalSettings.SelfHosted ? "Bitwarden Self-Hosted" : "Bitwarden");
client.Timeout = TimeSpan.FromSeconds(30);
});
if (globalSettings.SelfHosted)
{
services.AddScoped<ICloudPhishingDomainQuery, CloudPhishingDomainRelayQuery>();
}
else
{
services.AddScoped<ICloudPhishingDomainQuery, CloudPhishingDomainDirectQuery>();
}
}
}

View File

@ -37,6 +37,10 @@
},
"storage": {
"connectionString": "UseDevelopmentStorage=true"
},
"phishingDomain": {
"updateUrl": "https://phish.co.za/latest/phishing-domains-ACTIVE.txt",
"checksumUrl": "https://raw.githubusercontent.com/Phishing-Database/checksums/refs/heads/master/phishing-domains-ACTIVE.txt.sha256"
}
}
}

View File

@ -71,6 +71,9 @@
"accessKeySecret": "SECRET",
"region": "SECRET"
},
"phishingDomain": {
"updateUrl": "SECRET"
},
"distributedIpRateLimiting": {
"enabled": true,
"maxRedisTimeoutsThreshold": 10,

View File

@ -0,0 +1,104 @@
using Bit.Core.PhishingDomainFeatures.Interfaces;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
namespace Bit.Core.PhishingDomainFeatures;
/// <summary>
/// Implementation of ICloudPhishingDomainQuery for cloud environments
/// that directly calls the external phishing domain source
/// </summary>
public class CloudPhishingDomainDirectQuery : ICloudPhishingDomainQuery
{
private readonly IGlobalSettings _globalSettings;
private readonly IHttpClientFactory _httpClientFactory;
private readonly ILogger<CloudPhishingDomainDirectQuery> _logger;
public CloudPhishingDomainDirectQuery(
IGlobalSettings globalSettings,
IHttpClientFactory httpClientFactory,
ILogger<CloudPhishingDomainDirectQuery> logger)
{
_globalSettings = globalSettings;
_httpClientFactory = httpClientFactory;
_logger = logger;
}
public async Task<List<string>> GetPhishingDomainsAsync()
{
if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.UpdateUrl))
{
throw new InvalidOperationException("Phishing domain update URL is not configured.");
}
var httpClient = _httpClientFactory.CreateClient("PhishingDomains");
var response = await httpClient.GetAsync(_globalSettings.PhishingDomain.UpdateUrl);
response.EnsureSuccessStatusCode();
var content = await response.Content.ReadAsStringAsync();
return ParseDomains(content);
}
/// <summary>
/// Gets the SHA256 checksum of the remote phishing domains list
/// </summary>
/// <returns>The SHA256 checksum as a lowercase hex string</returns>
public async Task<string> GetRemoteChecksumAsync()
{
if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.ChecksumUrl))
{
_logger.LogWarning("Phishing domain checksum URL is not configured.");
return string.Empty;
}
try
{
var httpClient = _httpClientFactory.CreateClient("PhishingDomains");
var response = await httpClient.GetAsync(_globalSettings.PhishingDomain.ChecksumUrl);
response.EnsureSuccessStatusCode();
var content = await response.Content.ReadAsStringAsync();
return ParseChecksumResponse(content);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error retrieving phishing domain checksum from {Url}",
_globalSettings.PhishingDomain.ChecksumUrl);
return string.Empty;
}
}
/// <summary>
/// Parses a checksum response in the format "hash *filename"
/// </summary>
private static string ParseChecksumResponse(string checksumContent)
{
if (string.IsNullOrWhiteSpace(checksumContent))
{
return string.Empty;
}
// Format is typically "hash *filename"
var parts = checksumContent.Split(' ', 2);
if (parts.Length > 0)
{
return parts[0].Trim();
}
return string.Empty;
}
private static List<string> ParseDomains(string content)
{
if (string.IsNullOrWhiteSpace(content))
{
return [];
}
return content
.Split(['\r', '\n'], StringSplitOptions.RemoveEmptyEntries)
.Select(line => line.Trim())
.Where(line => !string.IsNullOrWhiteSpace(line) && !line.StartsWith("#"))
.ToList();
}
}

View File

@ -0,0 +1,66 @@
using Bit.Core.PhishingDomainFeatures.Interfaces;
using Bit.Core.Services;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
namespace Bit.Core.PhishingDomainFeatures;
/// <summary>
/// Implementation of ICloudPhishingDomainQuery for self-hosted environments
/// that relays the request to the Bitwarden cloud API
/// </summary>
public class CloudPhishingDomainRelayQuery : BaseIdentityClientService, ICloudPhishingDomainQuery
{
private readonly IGlobalSettings _globalSettings;
public CloudPhishingDomainRelayQuery(
IHttpClientFactory httpFactory,
IGlobalSettings globalSettings,
ILogger<CloudPhishingDomainRelayQuery> logger)
: base(
httpFactory,
globalSettings.Installation.ApiUri,
globalSettings.Installation.IdentityUri,
"api.installation",
$"installation.{globalSettings.Installation.Id}",
globalSettings.Installation.Key,
logger)
{
_globalSettings = globalSettings;
}
public async Task<List<string>> GetPhishingDomainsAsync()
{
if (!_globalSettings.SelfHosted || !_globalSettings.EnableCloudCommunication)
{
throw new InvalidOperationException("This query is only for self-hosted installations with cloud communication enabled.");
}
var result = await SendAsync<object, string[]>(HttpMethod.Get, "phishing-domains", null, true);
return result?.ToList() ?? new List<string>();
}
/// <summary>
/// Gets the SHA256 checksum of the remote phishing domains list
/// </summary>
/// <returns>The SHA256 checksum as a lowercase hex string</returns>
public async Task<string> GetRemoteChecksumAsync()
{
if (!_globalSettings.SelfHosted || !_globalSettings.EnableCloudCommunication)
{
throw new InvalidOperationException("This query is only for self-hosted installations with cloud communication enabled.");
}
try
{
// For self-hosted environments, we get the checksum from the Bitwarden cloud API
var result = await SendAsync<object, string>(HttpMethod.Get, "phishing-domains/checksum", null, true);
return result ?? string.Empty;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error retrieving phishing domain checksum from Bitwarden cloud API");
return string.Empty;
}
}
}

View File

@ -0,0 +1,7 @@
namespace Bit.Core.PhishingDomainFeatures.Interfaces;
public interface ICloudPhishingDomainQuery
{
Task<List<string>> GetPhishingDomainsAsync();
Task<string> GetRemoteChecksumAsync();
}

View File

@ -3,5 +3,6 @@
public interface IPhishingDomainRepository
{
Task<ICollection<string>> GetActivePhishingDomainsAsync();
Task UpdatePhishingDomainsAsync(IEnumerable<string> domains);
Task UpdatePhishingDomainsAsync(IEnumerable<string> domains, string checksum);
Task<string> GetCurrentChecksumAsync();
}

View File

@ -84,6 +84,7 @@ public class GlobalSettings : IGlobalSettings
public virtual ILaunchDarklySettings LaunchDarkly { get; set; } = new LaunchDarklySettings();
public virtual string DevelopmentDirectory { get; set; }
public virtual IWebPushSettings WebPush { get; set; } = new WebPushSettings();
public virtual IPhishingDomainSettings PhishingDomain { get; set; } = new PhishingDomainSettings();
public virtual bool EnableEmailVerification { get; set; }
public virtual string KdfDefaultHashKey { get; set; }
@ -634,6 +635,12 @@ public class GlobalSettings : IGlobalSettings
public int MaxNetworkRetries { get; set; } = 2;
}
public class PhishingDomainSettings : IPhishingDomainSettings
{
public string UpdateUrl { get; set; }
public string ChecksumUrl { get; set; }
}
public class DistributedIpRateLimitingSettings
{
public string RedisConnectionString { get; set; }

View File

@ -29,4 +29,5 @@ public interface IGlobalSettings
string DevelopmentDirectory { get; set; }
IWebPushSettings WebPush { get; set; }
GlobalSettings.EventLoggingSettings EventLogging { get; set; }
IPhishingDomainSettings PhishingDomain { get; set; }
}

View File

@ -0,0 +1,7 @@
namespace Bit.Core.Settings;
public interface IPhishingDomainSettings
{
string UpdateUrl { get; set; }
string ChecksumUrl { get; set; }
}

View File

@ -72,10 +72,4 @@ public static class DapperServiceCollectionExtensions
services.AddSingleton<IEventRepository, EventRepository>();
}
}
public static void AddDapper(this IServiceCollection services)
{
// Register repositories
services.AddSingleton<IPhishingDomainRepository, PhishingDomainRepository>();
}
}

View File

@ -5,6 +5,7 @@ using Bit.Core.Settings;
using Dapper;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Logging;
namespace Bit.Infrastructure.Dapper.Repositories;
@ -12,77 +13,152 @@ public class PhishingDomainRepository : IPhishingDomainRepository
{
private readonly string _connectionString;
private readonly IDistributedCache _cache;
private const string _cacheKey = "PhishingDomains";
private static readonly DistributedCacheEntryOptions _cacheOptions = new DistributedCacheEntryOptions
private readonly ILogger<PhishingDomainRepository> _logger;
private const string _cacheKey = "PhishingDomains_v1";
private static readonly DistributedCacheEntryOptions _cacheOptions = new()
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(24) // Cache for 24 hours
AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(24),
SlidingExpiration = TimeSpan.FromHours(1)
};
public PhishingDomainRepository(GlobalSettings globalSettings, IDistributedCache cache)
: this(globalSettings.SqlServer.ConnectionString, cache)
public PhishingDomainRepository(
GlobalSettings globalSettings,
IDistributedCache cache,
ILogger<PhishingDomainRepository> logger)
: this(globalSettings.SqlServer.ConnectionString, cache, logger)
{ }
public PhishingDomainRepository(string connectionString, IDistributedCache cache)
public PhishingDomainRepository(
string connectionString,
IDistributedCache cache,
ILogger<PhishingDomainRepository> logger)
{
_connectionString = connectionString;
_cache = cache;
_logger = logger;
}
public async Task<ICollection<string>> GetActivePhishingDomainsAsync()
{
// Try to get from cache first
var cachedDomains = await _cache.GetStringAsync(_cacheKey);
if (!string.IsNullOrEmpty(cachedDomains))
try
{
return JsonSerializer.Deserialize<ICollection<string>>(cachedDomains) ?? new List<string>();
var cachedDomains = await _cache.GetStringAsync(_cacheKey);
if (!string.IsNullOrEmpty(cachedDomains))
{
_logger.LogDebug("Retrieved phishing domains from cache");
return JsonSerializer.Deserialize<ICollection<string>>(cachedDomains) ?? [];
}
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to retrieve phishing domains from cache");
}
// If not in cache, get from database
using (var connection = new SqlConnection(_connectionString))
await using var connection = new SqlConnection(_connectionString);
var results = await connection.QueryAsync<string>(
"[dbo].[PhishingDomain_ReadAll]",
commandType: CommandType.StoredProcedure);
var domains = results.AsList();
try
{
var results = await connection.QueryAsync<string>(
"[dbo].[PhishingDomain_ReadAll]",
commandType: CommandType.StoredProcedure);
var domains = results.AsList();
// Store in cache
await _cache.SetStringAsync(
_cacheKey,
JsonSerializer.Serialize(domains),
_cacheOptions);
return domains;
_logger.LogDebug("Stored {Count} phishing domains in cache", domains.Count);
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to store phishing domains in cache");
}
return domains;
}
public async Task<string> GetCurrentChecksumAsync()
{
try
{
await using var connection = new SqlConnection(_connectionString);
var checksum = await connection.QueryFirstOrDefaultAsync<string>(
"[dbo].[PhishingDomain_ReadChecksum]",
commandType: CommandType.StoredProcedure);
return checksum ?? string.Empty;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error retrieving phishing domain checksum from database");
return string.Empty;
}
}
public async Task UpdatePhishingDomainsAsync(IEnumerable<string> domains)
public async Task UpdatePhishingDomainsAsync(IEnumerable<string> domains, string checksum)
{
using (var connection = new SqlConnection(_connectionString))
var domainsList = domains.ToList();
_logger.LogInformation("Beginning bulk update of {Count} phishing domains with checksum {Checksum}",
domainsList.Count, checksum);
await using var connection = new SqlConnection(_connectionString);
await connection.OpenAsync();
await using var transaction = connection.BeginTransaction();
try
{
await connection.ExecuteAsync(
"[dbo].[PhishingDomain_DeleteAll]",
transaction: transaction,
commandType: CommandType.StoredProcedure);
foreach (var domain in domains)
var dataTable = new DataTable();
dataTable.Columns.Add("Id", typeof(Guid));
dataTable.Columns.Add("Domain", typeof(string));
dataTable.Columns.Add("Checksum", typeof(string));
dataTable.PrimaryKey = [dataTable.Columns["Id"]];
foreach (var domain in domainsList)
{
await connection.ExecuteAsync(
"[dbo].[PhishingDomain_Create]",
new
{
Id = Guid.NewGuid(),
Domain = domain,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow
},
commandType: CommandType.StoredProcedure);
dataTable.Rows.Add(Guid.NewGuid(), domain, checksum);
}
using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.Default, transaction);
bulkCopy.DestinationTableName = "[dbo].[PhishingDomain]";
bulkCopy.BatchSize = 10000;
bulkCopy.ColumnMappings.Add("Id", "Id");
bulkCopy.ColumnMappings.Add("Domain", "Domain");
bulkCopy.ColumnMappings.Add("Checksum", "Checksum");
await bulkCopy.WriteToServerAsync(dataTable);
await transaction.CommitAsync();
_logger.LogInformation("Successfully bulk updated {Count} phishing domains", domainsList.Count);
}
catch (Exception ex)
{
await transaction.RollbackAsync();
_logger.LogError(ex, "Failed to bulk update phishing domains");
throw;
}
// Update cache with new domains
await _cache.SetStringAsync(
_cacheKey,
JsonSerializer.Serialize(domains),
_cacheOptions);
try
{
await _cache.SetStringAsync(
_cacheKey,
JsonSerializer.Serialize(domainsList),
_cacheOptions);
_logger.LogDebug("Updated phishing domains cache after update operation");
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to update phishing domains in cache");
}
}
}

View File

@ -11,7 +11,6 @@ public class PhishingDomain
[MaxLength(255)]
public string Domain { get; set; }
public DateTime CreationDate { get; set; }
public DateTime RevisionDate { get; set; }
[MaxLength(64)]
public string Checksum { get; set; }
}

View File

@ -111,6 +111,7 @@ public class DatabaseContext : DbContext
var eOrganizationConnection = builder.Entity<OrganizationConnection>();
var eOrganizationDomain = builder.Entity<OrganizationDomain>();
var aWebAuthnCredential = builder.Entity<WebAuthnCredential>();
var ePhishingDomain = builder.Entity<PhishingDomain>();
// Shadow property configurations go here
@ -127,6 +128,7 @@ public class DatabaseContext : DbContext
eOrganizationConnection.Property(c => c.Id).ValueGeneratedNever();
eOrganizationDomain.Property(ar => ar.Id).ValueGeneratedNever();
aWebAuthnCredential.Property(ar => ar.Id).ValueGeneratedNever();
ePhishingDomain.Property(ar => ar.Id).ValueGeneratedNever();
eCollectionCipher.HasKey(cc => new { cc.CollectionId, cc.CipherId });
eCollectionUser.HasKey(cu => new { cu.CollectionId, cu.OrganizationUserId });
@ -167,6 +169,7 @@ public class DatabaseContext : DbContext
eOrganizationConnection.ToTable(nameof(OrganizationConnection));
eOrganizationDomain.ToTable(nameof(OrganizationDomain));
aWebAuthnCredential.ToTable(nameof(WebAuthnCredential));
ePhishingDomain.ToTable(nameof(PhishingDomain));
ConfigureDateTimeUtcQueries(builder);
}

View File

@ -1,50 +1,167 @@
using Bit.Core.Repositories;
using Bit.Infrastructure.EntityFramework.Models;
using System.Data;
using System.Text.Json;
using Bit.Core.Repositories;
using Microsoft.Data.SqlClient;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
namespace Bit.Infrastructure.EntityFramework.Repositories;
public class PhishingDomainRepository : IPhishingDomainRepository
{
private readonly IServiceScopeFactory _serviceScopeFactory;
private readonly IDistributedCache _cache;
private readonly ILogger<PhishingDomainRepository> _logger;
private const string _cacheKey = "PhishingDomains_v1";
private static readonly DistributedCacheEntryOptions _cacheOptions = new()
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(24),
SlidingExpiration = TimeSpan.FromHours(1)
};
public PhishingDomainRepository(IServiceScopeFactory serviceScopeFactory)
public PhishingDomainRepository(
IServiceScopeFactory serviceScopeFactory,
IDistributedCache cache,
ILogger<PhishingDomainRepository> logger)
{
_serviceScopeFactory = serviceScopeFactory;
_cache = cache;
_logger = logger;
}
public async Task<ICollection<string>> GetActivePhishingDomainsAsync()
{
using (var scope = _serviceScopeFactory.CreateScope())
try
{
var cachedDomains = await _cache.GetStringAsync(_cacheKey);
if (!string.IsNullOrEmpty(cachedDomains))
{
_logger.LogDebug("Retrieved phishing domains from cache");
return JsonSerializer.Deserialize<ICollection<string>>(cachedDomains) ?? [];
}
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to retrieve phishing domains from cache");
}
using var scope = _serviceScopeFactory.CreateScope();
var dbContext = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
var domains = await dbContext.PhishingDomains
.Select(d => d.Domain)
.ToListAsync();
try
{
await _cache.SetStringAsync(
_cacheKey,
JsonSerializer.Serialize(domains),
_cacheOptions);
_logger.LogDebug("Stored {Count} phishing domains in cache", domains.Count);
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to store phishing domains in cache");
}
return domains;
}
public async Task<string> GetCurrentChecksumAsync()
{
try
{
using var scope = _serviceScopeFactory.CreateScope();
var dbContext = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
var domains = await dbContext.PhishingDomains
.Select(d => d.Domain)
.ToListAsync();
return domains;
// Get the first checksum in the database (there should only be one set of domains with the same checksum)
var checksum = await dbContext.PhishingDomains
.Select(d => d.Checksum)
.FirstOrDefaultAsync();
return checksum ?? string.Empty;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error retrieving phishing domain checksum from database");
return string.Empty;
}
}
public async Task UpdatePhishingDomainsAsync(IEnumerable<string> domains)
public async Task UpdatePhishingDomainsAsync(IEnumerable<string> domains, string checksum)
{
using (var scope = _serviceScopeFactory.CreateScope())
var domainsList = domains.ToList();
_logger.LogInformation("Beginning bulk update of {Count} phishing domains with checksum {Checksum}",
domainsList.Count, checksum);
using var scope = _serviceScopeFactory.CreateScope();
var dbContext = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
var connection = dbContext.Database.GetDbConnection();
var connectionString = connection.ConnectionString;
await using var sqlConnection = new SqlConnection(connectionString);
await sqlConnection.OpenAsync();
await using var transaction = sqlConnection.BeginTransaction();
try
{
var dbContext = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
await using var command = sqlConnection.CreateCommand();
command.Transaction = transaction;
command.CommandText = "[dbo].[PhishingDomain_DeleteAll]";
command.CommandType = CommandType.StoredProcedure;
await command.ExecuteNonQueryAsync();
// Clear existing domains
await dbContext.PhishingDomains.ExecuteDeleteAsync();
var dataTable = new DataTable();
dataTable.Columns.Add("Id", typeof(Guid));
dataTable.Columns.Add("Domain", typeof(string));
dataTable.Columns.Add("Checksum", typeof(string));
// Add new domains
var phishingDomains = domains.Select(d => new PhishingDomain
dataTable.PrimaryKey = [dataTable.Columns["Id"]];
foreach (var domain in domainsList)
{
Id = Guid.NewGuid(),
Domain = d,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow
});
await dbContext.PhishingDomains.AddRangeAsync(phishingDomains);
await dbContext.SaveChangesAsync();
dataTable.Rows.Add(Guid.NewGuid(), domain, checksum);
}
using var bulkCopy = new SqlBulkCopy(sqlConnection, SqlBulkCopyOptions.Default, transaction);
bulkCopy.DestinationTableName = "[dbo].[PhishingDomain]";
bulkCopy.BatchSize = 10000;
bulkCopy.ColumnMappings.Add("Id", "Id");
bulkCopy.ColumnMappings.Add("Domain", "Domain");
bulkCopy.ColumnMappings.Add("Checksum", "Checksum");
await bulkCopy.WriteToServerAsync(dataTable);
await transaction.CommitAsync();
_logger.LogInformation("Successfully bulk updated {Count} phishing domains", domainsList.Count);
}
catch (Exception ex)
{
await transaction.RollbackAsync();
_logger.LogError(ex, "Failed to bulk update phishing domains");
throw;
}
try
{
await _cache.SetStringAsync(
_cacheKey,
JsonSerializer.Serialize(domainsList),
_cacheOptions);
_logger.LogDebug("Updated phishing domains cache after update operation");
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to update phishing domains in cache");
}
}
}

View File

@ -1,8 +1,7 @@
CREATE PROCEDURE [dbo].[PhishingDomain_Create]
@Id UNIQUEIDENTIFIER,
@Domain NVARCHAR(255),
@CreationDate DATETIME2(7),
@RevisionDate DATETIME2(7)
@Checksum NVARCHAR(64)
AS
BEGIN
SET NOCOUNT ON
@ -11,14 +10,12 @@ BEGIN
(
[Id],
[Domain],
[CreationDate],
[RevisionDate]
[Checksum]
)
VALUES
(
@Id,
@Domain,
@CreationDate,
@RevisionDate
@Checksum
)
END

View File

@ -0,0 +1,10 @@
CREATE PROCEDURE [dbo].[PhishingDomain_ReadChecksum]
AS
BEGIN
SET NOCOUNT ON
SELECT TOP 1
[Checksum]
FROM
[dbo].[PhishingDomain]
END

View File

@ -1,8 +1,7 @@
CREATE TABLE [dbo].[PhishingDomain] (
[Id] UNIQUEIDENTIFIER NOT NULL,
[Domain] NVARCHAR(255) NOT NULL,
[CreationDate] DATETIME2(7) NOT NULL,
[RevisionDate] DATETIME2(7) NOT NULL,
[Id] UNIQUEIDENTIFIER NOT NULL,
[Domain] NVARCHAR(255) NOT NULL,
[Checksum] NVARCHAR(64) NULL,
CONSTRAINT [PK_PhishingDomain] PRIMARY KEY CLUSTERED ([Id] ASC)
);

View File

@ -0,0 +1,61 @@
-- Update PhishingDomain table to use Checksum instead of dates
IF EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'PhishingDomain' AND COLUMN_NAME = 'CreationDate')
BEGIN
-- Add Checksum column
ALTER TABLE [dbo].[PhishingDomain]
ADD [Checksum] NVARCHAR(64) NULL;
-- Drop old columns
ALTER TABLE [dbo].[PhishingDomain]
DROP COLUMN [CreationDate], [RevisionDate];
END
GO
-- Update PhishingDomain_Create stored procedure
IF OBJECT_ID('[dbo].[PhishingDomain_Create]') IS NOT NULL
BEGIN
DROP PROCEDURE [dbo].[PhishingDomain_Create]
END
GO
CREATE PROCEDURE [dbo].[PhishingDomain_Create]
@Id UNIQUEIDENTIFIER,
@Domain NVARCHAR(255),
@Checksum NVARCHAR(64)
AS
BEGIN
SET NOCOUNT ON
INSERT INTO [dbo].[PhishingDomain]
(
[Id],
[Domain],
[Checksum]
)
VALUES
(
@Id,
@Domain,
@Checksum
)
END
GO
-- Create PhishingDomain_ReadChecksum stored procedure
IF OBJECT_ID('[dbo].[PhishingDomain_ReadChecksum]') IS NOT NULL
BEGIN
DROP PROCEDURE [dbo].[PhishingDomain_ReadChecksum]
END
GO
CREATE PROCEDURE [dbo].[PhishingDomain_ReadChecksum]
AS
BEGIN
SET NOCOUNT ON
SELECT TOP 1
[Checksum]
FROM
[dbo].[PhishingDomain]
END
GO