diff --git a/src/Api/Controllers/PhishingDomainsController.cs b/src/Api/Controllers/PhishingDomainsController.cs index 433db97ccc..f35522c9d4 100644 --- a/src/Api/Controllers/PhishingDomainsController.cs +++ b/src/Api/Controllers/PhishingDomainsController.cs @@ -14,4 +14,11 @@ public class PhishingDomainsController(IPhishingDomainRepository phishingDomainR var domains = await phishingDomainRepository.GetActivePhishingDomainsAsync(); return Ok(domains); } + + [HttpGet("checksum")] + public async Task> GetChecksumAsync() + { + var checksum = await phishingDomainRepository.GetCurrentChecksumAsync(); + return Ok(checksum); + } } diff --git a/src/Api/Jobs/JobsHostedService.cs b/src/Api/Jobs/JobsHostedService.cs index acd95a0213..57b827a8be 100644 --- a/src/Api/Jobs/JobsHostedService.cs +++ b/src/Api/Jobs/JobsHostedService.cs @@ -58,6 +58,13 @@ public class JobsHostedService : BaseJobsHostedService .StartNow() .WithCronSchedule("0 0 * * * ?") .Build(); + var updatePhishingDomainsTrigger = TriggerBuilder.Create() + .WithIdentity("UpdatePhishingDomainsTrigger") + .StartNow() + .WithSimpleSchedule(x => x + .WithIntervalInHours(24) + .RepeatForever()) + .Build(); var jobs = new List> @@ -68,6 +75,7 @@ public class JobsHostedService : BaseJobsHostedService new Tuple(typeof(ValidateUsersJob), everyTopOfTheSixthHourTrigger), new Tuple(typeof(ValidateOrganizationsJob), everyTwelfthHourAndThirtyMinutesTrigger), new Tuple(typeof(ValidateOrganizationDomainJob), validateOrganizationDomainTrigger), + new Tuple(typeof(UpdatePhishingDomainsJob), updatePhishingDomainsTrigger), }; if (_globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication) @@ -96,6 +104,7 @@ public class JobsHostedService : BaseJobsHostedService services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); } public static void AddCommercialSecretsManagerJobServices(IServiceCollection services) diff --git a/src/Api/Jobs/UpdatePhishingDomainsJob.cs b/src/Api/Jobs/UpdatePhishingDomainsJob.cs new file mode 100644 index 0000000000..3837013898 --- /dev/null +++ b/src/Api/Jobs/UpdatePhishingDomainsJob.cs @@ -0,0 +1,87 @@ +using Bit.Core; +using Bit.Core.Jobs; +using Bit.Core.PhishingDomainFeatures.Interfaces; +using Bit.Core.Repositories; +using Bit.Core.Settings; +using Quartz; + +namespace Bit.Api.Jobs; + +public class UpdatePhishingDomainsJob : BaseJob +{ + private readonly GlobalSettings _globalSettings; + private readonly IPhishingDomainRepository _phishingDomainRepository; + private readonly ICloudPhishingDomainQuery _cloudPhishingDomainQuery; + + public UpdatePhishingDomainsJob( + GlobalSettings globalSettings, + IPhishingDomainRepository phishingDomainRepository, + ICloudPhishingDomainQuery cloudPhishingDomainQuery, + ILogger logger) + : base(logger) + { + _globalSettings = globalSettings; + _phishingDomainRepository = phishingDomainRepository; + _cloudPhishingDomainQuery = cloudPhishingDomainQuery; + } + + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.UpdateUrl)) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. No URL configured."); + return; + } + + if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. Cloud communication is disabled in global settings."); + return; + } + + // Get the remote checksum + var remoteChecksum = await _cloudPhishingDomainQuery.GetRemoteChecksumAsync(); + if (string.IsNullOrWhiteSpace(remoteChecksum)) + { + _logger.LogWarning(Constants.BypassFiltersEventId, "Could not retrieve remote checksum. Skipping update."); + return; + } + + // Get the current checksum from the database + var currentChecksum = await _phishingDomainRepository.GetCurrentChecksumAsync(); + + // Compare checksums to determine if update is needed + if (string.Equals(currentChecksum, remoteChecksum, StringComparison.OrdinalIgnoreCase)) + { + _logger.LogInformation(Constants.BypassFiltersEventId, + "Phishing domains list is up to date (checksum: {Checksum}). Skipping update.", + currentChecksum); + return; + } + + _logger.LogInformation(Constants.BypassFiltersEventId, + "Checksums differ (current: {CurrentChecksum}, remote: {RemoteChecksum}). Fetching updated domains from {Source}.", + currentChecksum, remoteChecksum, _globalSettings.SelfHosted ? "Bitwarden cloud API" : "external source"); + + try + { + var domains = await _cloudPhishingDomainQuery.GetPhishingDomainsAsync(); + + if (domains.Count > 0) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Updating {Count} phishing domains with checksum {Checksum}.", + domains.Count, remoteChecksum); + await _phishingDomainRepository.UpdatePhishingDomainsAsync(domains, remoteChecksum); + _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated phishing domains."); + } + else + { + _logger.LogWarning(Constants.BypassFiltersEventId, "No valid domains found in the response. Skipping update."); + } + } + catch (Exception ex) + { + _logger.LogError(Constants.BypassFiltersEventId, ex, "Error updating phishing domains."); + } + } +} diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index 5849bfb634..0b80625175 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -177,6 +177,7 @@ public class Startup services.AddBillingOperations(); services.AddReportingServices(); services.AddImportServices(); + services.AddPhishingDomainServices(globalSettings); // Authorization Handlers services.AddAuthorizationHandlers(); diff --git a/src/Api/Utilities/ServiceCollectionExtensions.cs b/src/Api/Utilities/ServiceCollectionExtensions.cs index feeac03e54..2280b576cb 100644 --- a/src/Api/Utilities/ServiceCollectionExtensions.cs +++ b/src/Api/Utilities/ServiceCollectionExtensions.cs @@ -2,6 +2,8 @@ using Bit.Api.Vault.AuthorizationHandlers.Collections; using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Authorization; using Bit.Core.IdentityServer; +using Bit.Core.PhishingDomainFeatures; +using Bit.Core.PhishingDomainFeatures.Interfaces; using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.Core.Vault.Authorization.SecurityTasks; @@ -106,4 +108,22 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); } + + public static void AddPhishingDomainServices(this IServiceCollection services, GlobalSettings globalSettings) + { + services.AddHttpClient("PhishingDomains", client => + { + client.DefaultRequestHeaders.Add("User-Agent", globalSettings.SelfHosted ? "Bitwarden Self-Hosted" : "Bitwarden"); + client.Timeout = TimeSpan.FromSeconds(30); + }); + + if (globalSettings.SelfHosted) + { + services.AddScoped(); + } + else + { + services.AddScoped(); + } + } } diff --git a/src/Api/appsettings.Development.json b/src/Api/appsettings.Development.json index 2f33d87ae8..82fb951261 100644 --- a/src/Api/appsettings.Development.json +++ b/src/Api/appsettings.Development.json @@ -37,6 +37,10 @@ }, "storage": { "connectionString": "UseDevelopmentStorage=true" + }, + "phishingDomain": { + "updateUrl": "https://phish.co.za/latest/phishing-domains-ACTIVE.txt", + "checksumUrl": "https://raw.githubusercontent.com/Phishing-Database/checksums/refs/heads/master/phishing-domains-ACTIVE.txt.sha256" } } } diff --git a/src/Api/appsettings.json b/src/Api/appsettings.json index 98b210cb1e..f8a69dcfac 100644 --- a/src/Api/appsettings.json +++ b/src/Api/appsettings.json @@ -71,6 +71,9 @@ "accessKeySecret": "SECRET", "region": "SECRET" }, + "phishingDomain": { + "updateUrl": "SECRET" + }, "distributedIpRateLimiting": { "enabled": true, "maxRedisTimeoutsThreshold": 10, diff --git a/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs b/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs new file mode 100644 index 0000000000..cee741b23b --- /dev/null +++ b/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs @@ -0,0 +1,104 @@ +using Bit.Core.PhishingDomainFeatures.Interfaces; +using Bit.Core.Settings; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.PhishingDomainFeatures; + +/// +/// Implementation of ICloudPhishingDomainQuery for cloud environments +/// that directly calls the external phishing domain source +/// +public class CloudPhishingDomainDirectQuery : ICloudPhishingDomainQuery +{ + private readonly IGlobalSettings _globalSettings; + private readonly IHttpClientFactory _httpClientFactory; + private readonly ILogger _logger; + + public CloudPhishingDomainDirectQuery( + IGlobalSettings globalSettings, + IHttpClientFactory httpClientFactory, + ILogger logger) + { + _globalSettings = globalSettings; + _httpClientFactory = httpClientFactory; + _logger = logger; + } + + public async Task> GetPhishingDomainsAsync() + { + if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.UpdateUrl)) + { + throw new InvalidOperationException("Phishing domain update URL is not configured."); + } + + var httpClient = _httpClientFactory.CreateClient("PhishingDomains"); + var response = await httpClient.GetAsync(_globalSettings.PhishingDomain.UpdateUrl); + response.EnsureSuccessStatusCode(); + + var content = await response.Content.ReadAsStringAsync(); + return ParseDomains(content); + } + + /// + /// Gets the SHA256 checksum of the remote phishing domains list + /// + /// The SHA256 checksum as a lowercase hex string + public async Task GetRemoteChecksumAsync() + { + if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.ChecksumUrl)) + { + _logger.LogWarning("Phishing domain checksum URL is not configured."); + return string.Empty; + } + + try + { + var httpClient = _httpClientFactory.CreateClient("PhishingDomains"); + var response = await httpClient.GetAsync(_globalSettings.PhishingDomain.ChecksumUrl); + response.EnsureSuccessStatusCode(); + + var content = await response.Content.ReadAsStringAsync(); + return ParseChecksumResponse(content); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error retrieving phishing domain checksum from {Url}", + _globalSettings.PhishingDomain.ChecksumUrl); + return string.Empty; + } + } + + /// + /// Parses a checksum response in the format "hash *filename" + /// + private static string ParseChecksumResponse(string checksumContent) + { + if (string.IsNullOrWhiteSpace(checksumContent)) + { + return string.Empty; + } + + // Format is typically "hash *filename" + var parts = checksumContent.Split(' ', 2); + if (parts.Length > 0) + { + return parts[0].Trim(); + } + + return string.Empty; + } + + private static List ParseDomains(string content) + { + if (string.IsNullOrWhiteSpace(content)) + { + return []; + } + + return content + .Split(['\r', '\n'], StringSplitOptions.RemoveEmptyEntries) + .Select(line => line.Trim()) + .Where(line => !string.IsNullOrWhiteSpace(line) && !line.StartsWith("#")) + .ToList(); + } +} diff --git a/src/Core/PhishingDomainFeatures/CloudPhishingDomainRelayQuery.cs b/src/Core/PhishingDomainFeatures/CloudPhishingDomainRelayQuery.cs new file mode 100644 index 0000000000..da279dab9a --- /dev/null +++ b/src/Core/PhishingDomainFeatures/CloudPhishingDomainRelayQuery.cs @@ -0,0 +1,66 @@ +using Bit.Core.PhishingDomainFeatures.Interfaces; +using Bit.Core.Services; +using Bit.Core.Settings; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.PhishingDomainFeatures; + +/// +/// Implementation of ICloudPhishingDomainQuery for self-hosted environments +/// that relays the request to the Bitwarden cloud API +/// +public class CloudPhishingDomainRelayQuery : BaseIdentityClientService, ICloudPhishingDomainQuery +{ + private readonly IGlobalSettings _globalSettings; + + public CloudPhishingDomainRelayQuery( + IHttpClientFactory httpFactory, + IGlobalSettings globalSettings, + ILogger logger) + : base( + httpFactory, + globalSettings.Installation.ApiUri, + globalSettings.Installation.IdentityUri, + "api.installation", + $"installation.{globalSettings.Installation.Id}", + globalSettings.Installation.Key, + logger) + { + _globalSettings = globalSettings; + } + + public async Task> GetPhishingDomainsAsync() + { + if (!_globalSettings.SelfHosted || !_globalSettings.EnableCloudCommunication) + { + throw new InvalidOperationException("This query is only for self-hosted installations with cloud communication enabled."); + } + + var result = await SendAsync(HttpMethod.Get, "phishing-domains", null, true); + return result?.ToList() ?? new List(); + } + + /// + /// Gets the SHA256 checksum of the remote phishing domains list + /// + /// The SHA256 checksum as a lowercase hex string + public async Task GetRemoteChecksumAsync() + { + if (!_globalSettings.SelfHosted || !_globalSettings.EnableCloudCommunication) + { + throw new InvalidOperationException("This query is only for self-hosted installations with cloud communication enabled."); + } + + try + { + // For self-hosted environments, we get the checksum from the Bitwarden cloud API + var result = await SendAsync(HttpMethod.Get, "phishing-domains/checksum", null, true); + return result ?? string.Empty; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error retrieving phishing domain checksum from Bitwarden cloud API"); + return string.Empty; + } + } +} diff --git a/src/Core/PhishingDomainFeatures/Interfaces/ICloudPhishingDomainQuery.cs b/src/Core/PhishingDomainFeatures/Interfaces/ICloudPhishingDomainQuery.cs new file mode 100644 index 0000000000..dac91747f7 --- /dev/null +++ b/src/Core/PhishingDomainFeatures/Interfaces/ICloudPhishingDomainQuery.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.PhishingDomainFeatures.Interfaces; + +public interface ICloudPhishingDomainQuery +{ + Task> GetPhishingDomainsAsync(); + Task GetRemoteChecksumAsync(); +} diff --git a/src/Core/Repositories/IPhishingDomainRepository.cs b/src/Core/Repositories/IPhishingDomainRepository.cs index aea6626ab4..2d653b0a43 100644 --- a/src/Core/Repositories/IPhishingDomainRepository.cs +++ b/src/Core/Repositories/IPhishingDomainRepository.cs @@ -3,5 +3,6 @@ public interface IPhishingDomainRepository { Task> GetActivePhishingDomainsAsync(); - Task UpdatePhishingDomainsAsync(IEnumerable domains); + Task UpdatePhishingDomainsAsync(IEnumerable domains, string checksum); + Task GetCurrentChecksumAsync(); } diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index 6bb76eb50a..d6831b255a 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -84,6 +84,7 @@ public class GlobalSettings : IGlobalSettings public virtual ILaunchDarklySettings LaunchDarkly { get; set; } = new LaunchDarklySettings(); public virtual string DevelopmentDirectory { get; set; } public virtual IWebPushSettings WebPush { get; set; } = new WebPushSettings(); + public virtual IPhishingDomainSettings PhishingDomain { get; set; } = new PhishingDomainSettings(); public virtual bool EnableEmailVerification { get; set; } public virtual string KdfDefaultHashKey { get; set; } @@ -634,6 +635,12 @@ public class GlobalSettings : IGlobalSettings public int MaxNetworkRetries { get; set; } = 2; } + public class PhishingDomainSettings : IPhishingDomainSettings + { + public string UpdateUrl { get; set; } + public string ChecksumUrl { get; set; } + } + public class DistributedIpRateLimitingSettings { public string RedisConnectionString { get; set; } diff --git a/src/Core/Settings/IGlobalSettings.cs b/src/Core/Settings/IGlobalSettings.cs index 411014ea32..d77842373e 100644 --- a/src/Core/Settings/IGlobalSettings.cs +++ b/src/Core/Settings/IGlobalSettings.cs @@ -29,4 +29,5 @@ public interface IGlobalSettings string DevelopmentDirectory { get; set; } IWebPushSettings WebPush { get; set; } GlobalSettings.EventLoggingSettings EventLogging { get; set; } + IPhishingDomainSettings PhishingDomain { get; set; } } diff --git a/src/Core/Settings/IPhishingDomainSettings.cs b/src/Core/Settings/IPhishingDomainSettings.cs new file mode 100644 index 0000000000..2e4a901a5a --- /dev/null +++ b/src/Core/Settings/IPhishingDomainSettings.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Settings; + +public interface IPhishingDomainSettings +{ + string UpdateUrl { get; set; } + string ChecksumUrl { get; set; } +} diff --git a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs index 228e62b9fa..d8a0b52c47 100644 --- a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs +++ b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs @@ -72,10 +72,4 @@ public static class DapperServiceCollectionExtensions services.AddSingleton(); } } - - public static void AddDapper(this IServiceCollection services) - { - // Register repositories - services.AddSingleton(); - } } diff --git a/src/Infrastructure.Dapper/Repositories/PhishingDomainRepository.cs b/src/Infrastructure.Dapper/Repositories/PhishingDomainRepository.cs index 5198b34372..9a732d4987 100644 --- a/src/Infrastructure.Dapper/Repositories/PhishingDomainRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/PhishingDomainRepository.cs @@ -5,6 +5,7 @@ using Bit.Core.Settings; using Dapper; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Logging; namespace Bit.Infrastructure.Dapper.Repositories; @@ -12,77 +13,152 @@ public class PhishingDomainRepository : IPhishingDomainRepository { private readonly string _connectionString; private readonly IDistributedCache _cache; - private const string _cacheKey = "PhishingDomains"; - private static readonly DistributedCacheEntryOptions _cacheOptions = new DistributedCacheEntryOptions + private readonly ILogger _logger; + private const string _cacheKey = "PhishingDomains_v1"; + private static readonly DistributedCacheEntryOptions _cacheOptions = new() { - AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(24) // Cache for 24 hours + AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(24), + SlidingExpiration = TimeSpan.FromHours(1) }; - public PhishingDomainRepository(GlobalSettings globalSettings, IDistributedCache cache) - : this(globalSettings.SqlServer.ConnectionString, cache) + public PhishingDomainRepository( + GlobalSettings globalSettings, + IDistributedCache cache, + ILogger logger) + : this(globalSettings.SqlServer.ConnectionString, cache, logger) { } - public PhishingDomainRepository(string connectionString, IDistributedCache cache) + public PhishingDomainRepository( + string connectionString, + IDistributedCache cache, + ILogger logger) { _connectionString = connectionString; _cache = cache; + _logger = logger; } public async Task> GetActivePhishingDomainsAsync() { - // Try to get from cache first - var cachedDomains = await _cache.GetStringAsync(_cacheKey); - if (!string.IsNullOrEmpty(cachedDomains)) + try { - return JsonSerializer.Deserialize>(cachedDomains) ?? new List(); + var cachedDomains = await _cache.GetStringAsync(_cacheKey); + if (!string.IsNullOrEmpty(cachedDomains)) + { + _logger.LogDebug("Retrieved phishing domains from cache"); + return JsonSerializer.Deserialize>(cachedDomains) ?? []; + } + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to retrieve phishing domains from cache"); } - // If not in cache, get from database - using (var connection = new SqlConnection(_connectionString)) + await using var connection = new SqlConnection(_connectionString); + + var results = await connection.QueryAsync( + "[dbo].[PhishingDomain_ReadAll]", + commandType: CommandType.StoredProcedure); + + var domains = results.AsList(); + + try { - var results = await connection.QueryAsync( - "[dbo].[PhishingDomain_ReadAll]", - commandType: CommandType.StoredProcedure); - - var domains = results.AsList(); - - // Store in cache await _cache.SetStringAsync( _cacheKey, JsonSerializer.Serialize(domains), _cacheOptions); - return domains; + _logger.LogDebug("Stored {Count} phishing domains in cache", domains.Count); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to store phishing domains in cache"); + } + + return domains; + } + + public async Task GetCurrentChecksumAsync() + { + try + { + await using var connection = new SqlConnection(_connectionString); + + var checksum = await connection.QueryFirstOrDefaultAsync( + "[dbo].[PhishingDomain_ReadChecksum]", + commandType: CommandType.StoredProcedure); + + return checksum ?? string.Empty; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error retrieving phishing domain checksum from database"); + return string.Empty; } } - public async Task UpdatePhishingDomainsAsync(IEnumerable domains) + public async Task UpdatePhishingDomainsAsync(IEnumerable domains, string checksum) { - using (var connection = new SqlConnection(_connectionString)) + var domainsList = domains.ToList(); + _logger.LogInformation("Beginning bulk update of {Count} phishing domains with checksum {Checksum}", + domainsList.Count, checksum); + + await using var connection = new SqlConnection(_connectionString); + await connection.OpenAsync(); + + await using var transaction = connection.BeginTransaction(); + try { await connection.ExecuteAsync( "[dbo].[PhishingDomain_DeleteAll]", + transaction: transaction, commandType: CommandType.StoredProcedure); - foreach (var domain in domains) + var dataTable = new DataTable(); + dataTable.Columns.Add("Id", typeof(Guid)); + dataTable.Columns.Add("Domain", typeof(string)); + dataTable.Columns.Add("Checksum", typeof(string)); + + dataTable.PrimaryKey = [dataTable.Columns["Id"]]; + + foreach (var domain in domainsList) { - await connection.ExecuteAsync( - "[dbo].[PhishingDomain_Create]", - new - { - Id = Guid.NewGuid(), - Domain = domain, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow - }, - commandType: CommandType.StoredProcedure); + dataTable.Rows.Add(Guid.NewGuid(), domain, checksum); } + + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.Default, transaction); + + bulkCopy.DestinationTableName = "[dbo].[PhishingDomain]"; + bulkCopy.BatchSize = 10000; + + bulkCopy.ColumnMappings.Add("Id", "Id"); + bulkCopy.ColumnMappings.Add("Domain", "Domain"); + bulkCopy.ColumnMappings.Add("Checksum", "Checksum"); + + await bulkCopy.WriteToServerAsync(dataTable); + await transaction.CommitAsync(); + + _logger.LogInformation("Successfully bulk updated {Count} phishing domains", domainsList.Count); + } + catch (Exception ex) + { + await transaction.RollbackAsync(); + _logger.LogError(ex, "Failed to bulk update phishing domains"); + throw; } - // Update cache with new domains - await _cache.SetStringAsync( - _cacheKey, - JsonSerializer.Serialize(domains), - _cacheOptions); + try + { + await _cache.SetStringAsync( + _cacheKey, + JsonSerializer.Serialize(domainsList), + _cacheOptions); + _logger.LogDebug("Updated phishing domains cache after update operation"); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to update phishing domains in cache"); + } } } diff --git a/src/Infrastructure.EntityFramework/Models/PhishingDomain.cs b/src/Infrastructure.EntityFramework/Models/PhishingDomain.cs index 842939d022..e11897ec38 100644 --- a/src/Infrastructure.EntityFramework/Models/PhishingDomain.cs +++ b/src/Infrastructure.EntityFramework/Models/PhishingDomain.cs @@ -11,7 +11,6 @@ public class PhishingDomain [MaxLength(255)] public string Domain { get; set; } - public DateTime CreationDate { get; set; } - - public DateTime RevisionDate { get; set; } + [MaxLength(64)] + public string Checksum { get; set; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index 17237c5207..c2f95e72f9 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -111,6 +111,7 @@ public class DatabaseContext : DbContext var eOrganizationConnection = builder.Entity(); var eOrganizationDomain = builder.Entity(); var aWebAuthnCredential = builder.Entity(); + var ePhishingDomain = builder.Entity(); // Shadow property configurations go here @@ -127,6 +128,7 @@ public class DatabaseContext : DbContext eOrganizationConnection.Property(c => c.Id).ValueGeneratedNever(); eOrganizationDomain.Property(ar => ar.Id).ValueGeneratedNever(); aWebAuthnCredential.Property(ar => ar.Id).ValueGeneratedNever(); + ePhishingDomain.Property(ar => ar.Id).ValueGeneratedNever(); eCollectionCipher.HasKey(cc => new { cc.CollectionId, cc.CipherId }); eCollectionUser.HasKey(cu => new { cu.CollectionId, cu.OrganizationUserId }); @@ -167,6 +169,7 @@ public class DatabaseContext : DbContext eOrganizationConnection.ToTable(nameof(OrganizationConnection)); eOrganizationDomain.ToTable(nameof(OrganizationDomain)); aWebAuthnCredential.ToTable(nameof(WebAuthnCredential)); + ePhishingDomain.ToTable(nameof(PhishingDomain)); ConfigureDateTimeUtcQueries(builder); } diff --git a/src/Infrastructure.EntityFramework/Repositories/PhishingDomainRepository.cs b/src/Infrastructure.EntityFramework/Repositories/PhishingDomainRepository.cs index 7266f01e7c..d0368ad809 100644 --- a/src/Infrastructure.EntityFramework/Repositories/PhishingDomainRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/PhishingDomainRepository.cs @@ -1,50 +1,167 @@ -using Bit.Core.Repositories; -using Bit.Infrastructure.EntityFramework.Models; +using System.Data; +using System.Text.Json; +using Bit.Core.Repositories; +using Microsoft.Data.SqlClient; using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; namespace Bit.Infrastructure.EntityFramework.Repositories; public class PhishingDomainRepository : IPhishingDomainRepository { private readonly IServiceScopeFactory _serviceScopeFactory; + private readonly IDistributedCache _cache; + private readonly ILogger _logger; + private const string _cacheKey = "PhishingDomains_v1"; + private static readonly DistributedCacheEntryOptions _cacheOptions = new() + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(24), + SlidingExpiration = TimeSpan.FromHours(1) + }; - public PhishingDomainRepository(IServiceScopeFactory serviceScopeFactory) + public PhishingDomainRepository( + IServiceScopeFactory serviceScopeFactory, + IDistributedCache cache, + ILogger logger) { _serviceScopeFactory = serviceScopeFactory; + _cache = cache; + _logger = logger; } public async Task> GetActivePhishingDomainsAsync() { - using (var scope = _serviceScopeFactory.CreateScope()) + try { + var cachedDomains = await _cache.GetStringAsync(_cacheKey); + + if (!string.IsNullOrEmpty(cachedDomains)) + { + _logger.LogDebug("Retrieved phishing domains from cache"); + return JsonSerializer.Deserialize>(cachedDomains) ?? []; + } + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to retrieve phishing domains from cache"); + } + + using var scope = _serviceScopeFactory.CreateScope(); + + var dbContext = scope.ServiceProvider.GetRequiredService(); + var domains = await dbContext.PhishingDomains + .Select(d => d.Domain) + .ToListAsync(); + + try + { + await _cache.SetStringAsync( + _cacheKey, + JsonSerializer.Serialize(domains), + _cacheOptions); + + _logger.LogDebug("Stored {Count} phishing domains in cache", domains.Count); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to store phishing domains in cache"); + } + + return domains; + } + + public async Task GetCurrentChecksumAsync() + { + try + { + using var scope = _serviceScopeFactory.CreateScope(); var dbContext = scope.ServiceProvider.GetRequiredService(); - var domains = await dbContext.PhishingDomains - .Select(d => d.Domain) - .ToListAsync(); - return domains; + + // Get the first checksum in the database (there should only be one set of domains with the same checksum) + var checksum = await dbContext.PhishingDomains + .Select(d => d.Checksum) + .FirstOrDefaultAsync(); + + return checksum ?? string.Empty; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error retrieving phishing domain checksum from database"); + return string.Empty; } } - public async Task UpdatePhishingDomainsAsync(IEnumerable domains) + public async Task UpdatePhishingDomainsAsync(IEnumerable domains, string checksum) { - using (var scope = _serviceScopeFactory.CreateScope()) + var domainsList = domains.ToList(); + _logger.LogInformation("Beginning bulk update of {Count} phishing domains with checksum {Checksum}", + domainsList.Count, checksum); + + using var scope = _serviceScopeFactory.CreateScope(); + var dbContext = scope.ServiceProvider.GetRequiredService(); + + var connection = dbContext.Database.GetDbConnection(); + var connectionString = connection.ConnectionString; + + await using var sqlConnection = new SqlConnection(connectionString); + await sqlConnection.OpenAsync(); + + await using var transaction = sqlConnection.BeginTransaction(); + try { - var dbContext = scope.ServiceProvider.GetRequiredService(); + await using var command = sqlConnection.CreateCommand(); + command.Transaction = transaction; + command.CommandText = "[dbo].[PhishingDomain_DeleteAll]"; + command.CommandType = CommandType.StoredProcedure; + await command.ExecuteNonQueryAsync(); - // Clear existing domains - await dbContext.PhishingDomains.ExecuteDeleteAsync(); + var dataTable = new DataTable(); + dataTable.Columns.Add("Id", typeof(Guid)); + dataTable.Columns.Add("Domain", typeof(string)); + dataTable.Columns.Add("Checksum", typeof(string)); - // Add new domains - var phishingDomains = domains.Select(d => new PhishingDomain + dataTable.PrimaryKey = [dataTable.Columns["Id"]]; + + foreach (var domain in domainsList) { - Id = Guid.NewGuid(), - Domain = d, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow - }); - await dbContext.PhishingDomains.AddRangeAsync(phishingDomains); - await dbContext.SaveChangesAsync(); + dataTable.Rows.Add(Guid.NewGuid(), domain, checksum); + } + + using var bulkCopy = new SqlBulkCopy(sqlConnection, SqlBulkCopyOptions.Default, transaction); + + bulkCopy.DestinationTableName = "[dbo].[PhishingDomain]"; + bulkCopy.BatchSize = 10000; + + bulkCopy.ColumnMappings.Add("Id", "Id"); + bulkCopy.ColumnMappings.Add("Domain", "Domain"); + bulkCopy.ColumnMappings.Add("Checksum", "Checksum"); + + await bulkCopy.WriteToServerAsync(dataTable); + await transaction.CommitAsync(); + + _logger.LogInformation("Successfully bulk updated {Count} phishing domains", domainsList.Count); + } + catch (Exception ex) + { + await transaction.RollbackAsync(); + _logger.LogError(ex, "Failed to bulk update phishing domains"); + throw; + } + + try + { + await _cache.SetStringAsync( + _cacheKey, + JsonSerializer.Serialize(domainsList), + _cacheOptions); + + _logger.LogDebug("Updated phishing domains cache after update operation"); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to update phishing domains in cache"); } } } diff --git a/src/Sql/dbo/Stored Procedures/PhishingDomain_Create.sql b/src/Sql/dbo/Stored Procedures/PhishingDomain_Create.sql index acb2da8876..fe29dffcd9 100644 --- a/src/Sql/dbo/Stored Procedures/PhishingDomain_Create.sql +++ b/src/Sql/dbo/Stored Procedures/PhishingDomain_Create.sql @@ -1,8 +1,7 @@ CREATE PROCEDURE [dbo].[PhishingDomain_Create] @Id UNIQUEIDENTIFIER, @Domain NVARCHAR(255), - @CreationDate DATETIME2(7), - @RevisionDate DATETIME2(7) + @Checksum NVARCHAR(64) AS BEGIN SET NOCOUNT ON @@ -11,14 +10,12 @@ BEGIN ( [Id], [Domain], - [CreationDate], - [RevisionDate] + [Checksum] ) VALUES ( @Id, @Domain, - @CreationDate, - @RevisionDate + @Checksum ) END \ No newline at end of file diff --git a/src/Sql/dbo/Stored Procedures/PhishingDomain_ReadChecksum.sql b/src/Sql/dbo/Stored Procedures/PhishingDomain_ReadChecksum.sql new file mode 100644 index 0000000000..83e23a8c1e --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/PhishingDomain_ReadChecksum.sql @@ -0,0 +1,10 @@ +CREATE PROCEDURE [dbo].[PhishingDomain_ReadChecksum] +AS +BEGIN + SET NOCOUNT ON + + SELECT TOP 1 + [Checksum] + FROM + [dbo].[PhishingDomain] +END \ No newline at end of file diff --git a/src/Sql/dbo/Tables/PhishingDomain.sql b/src/Sql/dbo/Tables/PhishingDomain.sql index fe5686ee68..f816666d6e 100644 --- a/src/Sql/dbo/Tables/PhishingDomain.sql +++ b/src/Sql/dbo/Tables/PhishingDomain.sql @@ -1,8 +1,7 @@ CREATE TABLE [dbo].[PhishingDomain] ( - [Id] UNIQUEIDENTIFIER NOT NULL, - [Domain] NVARCHAR(255) NOT NULL, - [CreationDate] DATETIME2(7) NOT NULL, - [RevisionDate] DATETIME2(7) NOT NULL, + [Id] UNIQUEIDENTIFIER NOT NULL, + [Domain] NVARCHAR(255) NOT NULL, + [Checksum] NVARCHAR(64) NULL, CONSTRAINT [PK_PhishingDomain] PRIMARY KEY CLUSTERED ([Id] ASC) ); diff --git a/util/Migrator/DbScripts/2024-05-17_00_PhishingDomainChecksum.sql b/util/Migrator/DbScripts/2024-05-17_00_PhishingDomainChecksum.sql new file mode 100644 index 0000000000..39823b3da7 --- /dev/null +++ b/util/Migrator/DbScripts/2024-05-17_00_PhishingDomainChecksum.sql @@ -0,0 +1,61 @@ +-- Update PhishingDomain table to use Checksum instead of dates +IF EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'PhishingDomain' AND COLUMN_NAME = 'CreationDate') +BEGIN + -- Add Checksum column + ALTER TABLE [dbo].[PhishingDomain] + ADD [Checksum] NVARCHAR(64) NULL; + + -- Drop old columns + ALTER TABLE [dbo].[PhishingDomain] + DROP COLUMN [CreationDate], [RevisionDate]; +END +GO + +-- Update PhishingDomain_Create stored procedure +IF OBJECT_ID('[dbo].[PhishingDomain_Create]') IS NOT NULL +BEGIN + DROP PROCEDURE [dbo].[PhishingDomain_Create] +END +GO + +CREATE PROCEDURE [dbo].[PhishingDomain_Create] + @Id UNIQUEIDENTIFIER, + @Domain NVARCHAR(255), + @Checksum NVARCHAR(64) +AS +BEGIN + SET NOCOUNT ON + + INSERT INTO [dbo].[PhishingDomain] + ( + [Id], + [Domain], + [Checksum] + ) + VALUES + ( + @Id, + @Domain, + @Checksum + ) +END +GO + +-- Create PhishingDomain_ReadChecksum stored procedure +IF OBJECT_ID('[dbo].[PhishingDomain_ReadChecksum]') IS NOT NULL +BEGIN + DROP PROCEDURE [dbo].[PhishingDomain_ReadChecksum] +END +GO + +CREATE PROCEDURE [dbo].[PhishingDomain_ReadChecksum] +AS +BEGIN + SET NOCOUNT ON + + SELECT TOP 1 + [Checksum] + FROM + [dbo].[PhishingDomain] +END +GO \ No newline at end of file