diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 8d1f009d40..3a7b654438 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -9,6 +9,8 @@ public static class Constants // in nginx/proxy.conf may also need to be updated accordingly. public const long FileSize101mb = 101L * 1024L * 1024L; public const long FileSize501mb = 501L * 1024L * 1024L; + public const string DatabaseFieldProtectorPurpose = "DatabaseFieldProtection"; + public const string DatabaseFieldProtectedPrefix = "P|"; } public static class TokenPurposes diff --git a/src/Infrastructure.Dapper/Repositories/UserRepository.cs b/src/Infrastructure.Dapper/Repositories/UserRepository.cs index 7272d8a251..feef8a8edb 100644 --- a/src/Infrastructure.Dapper/Repositories/UserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/UserRepository.cs @@ -1,18 +1,26 @@ using System.Data; +using Bit.Core; using Bit.Core.Entities; using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; +using Microsoft.AspNetCore.DataProtection; using Microsoft.Data.SqlClient; namespace Bit.Infrastructure.Dapper.Repositories; public class UserRepository : Repository, IUserRepository { - public UserRepository(GlobalSettings globalSettings) + private readonly IDataProtector _dataProtector; + + public UserRepository( + GlobalSettings globalSettings, + IDataProtectionProvider dataProtectionProvider) : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } + { + _dataProtector = dataProtectionProvider.CreateProtector(Constants.DatabaseFieldProtectorPurpose); + } public UserRepository(string connectionString, string readOnlyConnectionString) : base(connectionString, readOnlyConnectionString) @@ -20,7 +28,9 @@ public class UserRepository : Repository, IUserRepository public override async Task GetByIdAsync(Guid id) { - return await base.GetByIdAsync(id); + var user = await base.GetByIdAsync(id); + UnprotectData(user); + return user; } public async Task GetByEmailAsync(string email) @@ -32,6 +42,7 @@ public class UserRepository : Repository, IUserRepository new { Email = email }, commandType: CommandType.StoredProcedure); + UnprotectData(results); return results.SingleOrDefault(); } } @@ -45,6 +56,7 @@ public class UserRepository : Repository, IUserRepository new { OrganizationId = organizationId, ExternalId = externalId }, commandType: CommandType.StoredProcedure); + UnprotectData(results); return results.SingleOrDefault(); } } @@ -72,6 +84,7 @@ public class UserRepository : Repository, IUserRepository commandType: CommandType.StoredProcedure, commandTimeout: 120); + UnprotectData(results); return results.ToList(); } } @@ -85,6 +98,7 @@ public class UserRepository : Repository, IUserRepository new { Premium = premium }, commandType: CommandType.StoredProcedure); + UnprotectData(results); return results.ToList(); } } @@ -115,9 +129,15 @@ public class UserRepository : Repository, IUserRepository } } + public override async Task CreateAsync(User user) + { + await ProtectDataAndSaveAsync(user, async () => await base.CreateAsync(user)); + return user; + } + public override async Task ReplaceAsync(User user) { - await base.ReplaceAsync(user); + await ProtectDataAndSaveAsync(user, async () => await base.ReplaceAsync(user)); } public override async Task DeleteAsync(User user) @@ -164,7 +184,74 @@ public class UserRepository : Repository, IUserRepository new { Ids = ids.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); + UnprotectData(results); return results.ToList(); } } + + private async Task ProtectDataAndSaveAsync(User user, Func saveTask) + { + if (user == null) + { + await saveTask(); + return; + } + + // Capture original values + var originalMasterPassword = user.MasterPassword; + var originalKey = user.Key; + + // Protect values + if (!user.MasterPassword?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? false) + { + user.MasterPassword = string.Concat(Constants.DatabaseFieldProtectedPrefix, + _dataProtector.Protect(user.MasterPassword)); + } + + if (!user.Key?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? false) + { + user.Key = string.Concat(Constants.DatabaseFieldProtectedPrefix, + _dataProtector.Protect(user.Key)); + } + + // Save + await saveTask(); + + // Restore original values + user.MasterPassword = originalMasterPassword; + user.Key = originalKey; + } + + private void UnprotectData(User user) + { + if (user == null) + { + return; + } + + if (user.MasterPassword?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? false) + { + user.MasterPassword = _dataProtector.Unprotect( + user.MasterPassword.Substring(Constants.DatabaseFieldProtectedPrefix.Length)); + } + + if (user.Key?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? false) + { + user.Key = _dataProtector.Unprotect( + user.Key.Substring(Constants.DatabaseFieldProtectedPrefix.Length)); + } + } + + private void UnprotectData(IEnumerable users) + { + if (users == null) + { + return; + } + + foreach (var user in users) + { + UnprotectData(user); + } + } } diff --git a/src/Infrastructure.EntityFramework/Converters/DataProtectionConverter.cs b/src/Infrastructure.EntityFramework/Converters/DataProtectionConverter.cs new file mode 100644 index 0000000000..ee5c23fa71 --- /dev/null +++ b/src/Infrastructure.EntityFramework/Converters/DataProtectionConverter.cs @@ -0,0 +1,33 @@ +using Bit.Core; +using Microsoft.AspNetCore.DataProtection; +using Microsoft.EntityFrameworkCore.Storage.ValueConversion; + +namespace Bit.Infrastructure.EntityFramework.Converters; +public class DataProtectionConverter : ValueConverter +{ + public DataProtectionConverter(IDataProtector dataProtector) : + base(s => Protect(dataProtector, s), s => Unprotect(dataProtector, s)) + { } + + private static string Protect(IDataProtector dataProtector, string value) + { + if (value?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? true) + { + return value; + } + + return string.Concat( + Constants.DatabaseFieldProtectedPrefix, dataProtector.Protect(value)); + } + + private static string Unprotect(IDataProtector dataProtector, string value) + { + if (!value?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? true) + { + return value; + } + + return dataProtector.Unprotect( + value.Substring(Constants.DatabaseFieldProtectedPrefix.Length)); + } +} diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index c4a79d0a11..ffdcfdc24d 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -1,6 +1,10 @@ -using Bit.Infrastructure.EntityFramework.Models; +using Bit.Core; +using Bit.Infrastructure.EntityFramework.Converters; +using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Storage.ValueConversion; +using DP = Microsoft.AspNetCore.DataProtection; namespace Bit.Infrastructure.EntityFramework.Repositories; @@ -113,6 +117,12 @@ public class DatabaseContext : DbContext eGrant.HasKey(x => x.Key); eGroupUser.HasKey(gu => new { gu.GroupId, gu.OrganizationUserId }); + var dataProtector = this.GetService().CreateProtector( + Constants.DatabaseFieldProtectorPurpose); + var dataProtectionConverter = new DataProtectionConverter(dataProtector); + eUser.Property(c => c.Key).HasConversion(dataProtectionConverter); + eUser.Property(c => c.MasterPassword).HasConversion(dataProtectionConverter); + if (Database.IsNpgsql()) { // the postgres provider doesn't currently support database level non-deterministic collations. diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 9d72201017..fef0eb4857 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -408,7 +408,7 @@ public static class ServiceCollectionExtensions public static void AddCustomDataProtectionServices( this IServiceCollection services, IWebHostEnvironment env, GlobalSettings globalSettings) { - var builder = services.AddDataProtection(options => options.ApplicationDiscriminator = "Bitwarden"); + var builder = services.AddDataProtection().SetApplicationName("Bitwarden"); if (env.IsDevelopment()) { return; @@ -433,7 +433,6 @@ public static class ServiceCollectionExtensions "dataprotection.pfx", globalSettings.DataProtection.CertificatePassword) .GetAwaiter().GetResult(); } - //TODO djsmith85 Check if this is the correct container name builder .PersistKeysToAzureBlobStorage(globalSettings.Storage.ConnectionString, "aspnet-dataprotection", "keys.xml") .ProtectKeysWithCertificate(dataProtectionCert); diff --git a/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs b/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs index b9c053c290..39f57389ac 100644 --- a/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs +++ b/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs @@ -1,8 +1,12 @@ using System.Reflection; +using System.Text; using AutoFixture; using AutoFixture.Kernel; using AutoFixture.Xunit2; +using Bit.Core; using Bit.Core.Test.Helpers.Factories; +using Microsoft.AspNetCore.DataProtection; +using Moq; namespace Bit.Test.Common.AutoFixture; @@ -15,13 +19,34 @@ public class GlobalSettingsBuilder : ISpecimenBuilder throw new ArgumentNullException(nameof(context)); } - var pi = request as ParameterInfo; var fixture = new Fixture(); - if (pi == null || pi.ParameterType != typeof(Bit.Core.Settings.GlobalSettings)) + if (request is not ParameterInfo pi) + { return new NoSpecimen(); + } - return GlobalSettingsFactory.GlobalSettings; + if (pi.ParameterType == typeof(Bit.Core.Settings.GlobalSettings)) + { + return GlobalSettingsFactory.GlobalSettings; + } + + if (pi.ParameterType == typeof(IDataProtectionProvider)) + { + var dataProtector = new Mock(); + dataProtector + .Setup(d => d.Unprotect(It.IsAny())) + .Returns(data => Encoding.UTF8.GetBytes(Constants.DatabaseFieldProtectedPrefix + Encoding.UTF8.GetString(data))); + + var dataProtectionProvider = new Mock(); + dataProtectionProvider + .Setup(x => x.CreateProtector(Constants.DatabaseFieldProtectorPurpose)) + .Returns(dataProtector.Object); + + return dataProtectionProvider.Object; + } + + return new NoSpecimen(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs b/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs index fbf0d98286..4ee68aa8ee 100644 --- a/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs +++ b/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs @@ -1,6 +1,11 @@ -using Bit.Core.Test.Helpers.Factories; +using System.Text; +using Bit.Core; +using Bit.Core.Test.Helpers.Factories; using Bit.Infrastructure.EntityFramework.Repositories; +using Microsoft.AspNetCore.DataProtection; using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Moq; namespace Bit.Infrastructure.EFIntegration.Test.Helpers; @@ -10,16 +15,39 @@ public static class DatabaseOptionsFactory static DatabaseOptionsFactory() { + var services = new ServiceCollection() + .AddSingleton(sp => + { + var dataProtector = new Mock(); + dataProtector + .Setup(d => d.Unprotect(It.IsAny())) + .Returns(data => Encoding.UTF8.GetBytes(Constants.DatabaseFieldProtectedPrefix + Encoding.UTF8.GetString(data))); + + var dataProtectionProvider = new Mock(); + dataProtectionProvider + .Setup(x => x.CreateProtector(Constants.DatabaseFieldProtectorPurpose)) + .Returns(dataProtector.Object); + + return dataProtectionProvider.Object; + }) + .BuildServiceProvider(); + var globalSettings = GlobalSettingsFactory.GlobalSettings; if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.PostgreSql?.ConnectionString)) { AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true); - Options.Add(new DbContextOptionsBuilder().UseNpgsql(globalSettings.PostgreSql.ConnectionString).Options); + Options.Add(new DbContextOptionsBuilder() + .UseNpgsql(globalSettings.PostgreSql.ConnectionString) + .UseApplicationServiceProvider(services) + .Options); } if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.MySql?.ConnectionString)) { var mySqlConnectionString = globalSettings.MySql.ConnectionString; - Options.Add(new DbContextOptionsBuilder().UseMySql(mySqlConnectionString, ServerVersion.AutoDetect(mySqlConnectionString)).Options); + Options.Add(new DbContextOptionsBuilder() + .UseMySql(mySqlConnectionString, ServerVersion.AutoDetect(mySqlConnectionString)) + .UseApplicationServiceProvider(services) + .Options); } } } diff --git a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs index 9c32a46e7c..18ada0af05 100644 --- a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs +++ b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs @@ -56,10 +56,11 @@ public abstract class WebApplicationFactoryBase : WebApplicationFactory { var dbContextOptions = services.First(sd => sd.ServiceType == typeof(DbContextOptions)); services.Remove(dbContextOptions); - services.AddScoped(_ => + services.AddScoped(services => { return new DbContextOptionsBuilder() .UseInMemoryDatabase(DatabaseName) + .UseApplicationServiceProvider(services) .Options; });