diff --git a/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs b/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs index 696f4d8599..2460234f54 100644 --- a/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs +++ b/test/Infrastructure.IntegrationTest/DatabaseDataAttribute.cs @@ -1,16 +1,6 @@ using System.Reflection; -using Bit.Core.Enums; -using Bit.Core.Settings; -using Bit.Infrastructure.Dapper; -using Bit.Infrastructure.EntityFramework; -using Bit.Infrastructure.EntityFramework.Repositories; -using Bit.Infrastructure.IntegrationTest.Services; using Bit.Infrastructure.IntegrationTest.Utilities; -using Microsoft.Extensions.Caching.Distributed; -using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Time.Testing; using Xunit; using Xunit.Sdk; using Xunit.v3; @@ -19,29 +9,72 @@ namespace Bit.Infrastructure.IntegrationTest; public class DatabaseDataAttribute : DataAttribute { - public bool SelfHosted { get; set; } public bool UseFakeTimeProvider { get; set; } - public string? MigrationName { get; set; } - public override ValueTask> GetData(MethodInfo testMethod, DisposalTracker disposalTracker) + public override async ValueTask> GetData(MethodInfo testMethod, DisposalTracker disposalTracker) { - var builders = DatabaseStartup.Builders; + var customizers = GetOrderedCustomizers(testMethod); - if (builders == null) + var databases = DatabaseStartup.Databases; + + if (databases == null) { - throw new InvalidOperationException("Builders wasn't supplied, this likely means DatabaseStartup didn't run."); + throw new InvalidOperationException("Databases wasn't supplied, this likely means DatabaseStartup didn't run."); } - var theoryData = new ITheoryDataRow[builders.Count]; - for (var i = 0; i < builders.Count; i++) + var theories = new ITheoryDataRow[databases.Count]; + + for (var i = 0; i < theories.Length; i++) { - theoryData[i] = builders[i](testMethod, disposalTracker, this); + var customizationContext = new CustomizationContext(databases[i] with {}, testMethod, disposalTracker); + foreach (var customizer in customizers) + { + await customizer.CustomizeAsync(customizationContext); + } + + var isEnabled = customizationContext.Enabled ?? customizationContext.Database.Enabled; + + TheoryDataRowBase theory; + + if (!isEnabled) + { + theory = new TheoryDataRow() + .WithSkip("Not Enabled"); + } + else + { + theory = new ServiceTheoryDataRow(testMethod, disposalTracker, customizationContext.Services.BuildServiceProvider()); + } + + theory + .WithTrait("Type", customizationContext.Database.Type.ToString()) + .WithTrait("ConnectionString", customizationContext.Database.ConnectionString ?? "(none)") + .WithTestDisplayName($"{testMethod.Name}[{customizationContext.Database.Name ?? customizationContext.Database.Type.ToString()}]"); + + theories[i] = theory; } - return new(theoryData); + + return theories; } public override bool SupportsDiscoveryEnumeration() { return true; } + + private static IEnumerable GetOrderedCustomizers(MethodInfo methodInfo) + { + var assemblyAttributes = methodInfo.DeclaringType?.Assembly.GetCustomAttributes() ?? []; + var typeAttributes = methodInfo.DeclaringType?.GetCustomAttributes() ?? []; + var methodAttributes = methodInfo.GetCustomAttributes(); + + IReadOnlyCollection allAttributes = [..assemblyAttributes, ..typeAttributes, ..methodAttributes]; + + if (allAttributes.Count == 0) + { + return [DefaultCustomizerAttribute.Instance]; + } + + return allAttributes; + } } diff --git a/test/Infrastructure.IntegrationTest/Utilities/AutoMigrateAttribute.cs b/test/Infrastructure.IntegrationTest/Utilities/AutoMigrateAttribute.cs new file mode 100644 index 0000000000..e4d6a406ef --- /dev/null +++ b/test/Infrastructure.IntegrationTest/Utilities/AutoMigrateAttribute.cs @@ -0,0 +1,37 @@ + +using Bit.Core.Enums; +using Bit.Core.Utilities; +using Bit.Infrastructure.IntegrationTest.Services; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Infrastructure.IntegrationTest.Utilities; + +public class AutoMigrateAttribute : TestCustomizerAttribute +{ + public AutoMigrateAttribute(string? migrationName = null) + { + MigrationName = migrationName; + } + + public string? MigrationName { get; } + + public override Task CustomizeAsync(CustomizationContext customizationContext) + { + // Add migration services + var database = customizationContext.Database; + + if (database.Type == SupportedDatabaseProviders.SqlServer && !database.UseEf) + { + // Add migrator service + } + else + { + // Add migrator service + } + + // Build services provider early and run migrations + var sp = customizationContext.Services.BuildServiceProvider(); + var migrator = sp.GetRequiredService(); + migrator.ApplyMigration() + } +} diff --git a/test/Infrastructure.IntegrationTest/Utilities/DatabaseStartup.cs b/test/Infrastructure.IntegrationTest/Utilities/DatabaseStartup.cs index b44d987cac..10930bc814 100644 --- a/test/Infrastructure.IntegrationTest/Utilities/DatabaseStartup.cs +++ b/test/Infrastructure.IntegrationTest/Utilities/DatabaseStartup.cs @@ -1,27 +1,15 @@ -using System.Reflection; using Bit.Core.Enums; -using Bit.Core.Settings; -using Bit.Infrastructure.Dapper; -using Bit.Infrastructure.EntityFramework; -using Bit.Infrastructure.EntityFramework.Repositories; -using Bit.Infrastructure.IntegrationTest.Services; -using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Time.Testing; -using Xunit; using Xunit.v3; using Xunit.Sdk; namespace Bit.Infrastructure.IntegrationTest.Utilities; -using TheoryDataBuilder = Func; - -public class Database +public record Database { + public string? Name { get; set; } public SupportedDatabaseProviders Type { get; set; } - public string ConnectionString { get; set; } = default!; + public string? ConnectionString { get; set; } public bool UseEf { get; set; } public bool Enabled { get; set; } = true; } @@ -33,16 +21,36 @@ internal class TypedConfig public class DatabaseStartup : ITestPipelineStartup { - public static IReadOnlyList? Builders { get; private set; } + public static IReadOnlyList? Databases { get; private set; } public ValueTask StartAsync(IMessageSink diagnosticMessageSink) { - HashSet unconfiguredDatabases = + List unconfiguredDatabases = [ - SupportedDatabaseProviders.SqlServer, - SupportedDatabaseProviders.MySql, - SupportedDatabaseProviders.Postgres, - SupportedDatabaseProviders.Sqlite + new Database + { + Type = SupportedDatabaseProviders.SqlServer, + Enabled = false, + Name = "Unconfigured", + }, + new Database + { + Type = SupportedDatabaseProviders.MySql, + Enabled = false, + Name = "Unconfigured", + }, + new Database + { + Type = SupportedDatabaseProviders.Postgres, + Enabled = false, + Name = "Unconfigured", + }, + new Database + { + Type = SupportedDatabaseProviders.Sqlite, + Enabled = false, + Name = "Unconfigured", + }, ]; // Do startup things @@ -54,57 +62,24 @@ public class DatabaseStartup : ITestPipelineStartup var typedConfig = configuration.Get(); - var theories = new List(); if (typedConfig is not { Databases: var databases }) { - foreach (var unconfiguredDatabase in unconfiguredDatabases) - { - theories.Add((mi, _, _) => new TheoryDataRow() - .WithSkip("Unconfigured") - .WithTestDisplayName(TestName(mi, unconfiguredDatabase)) - .WithTrait("Type", unconfiguredDatabase.ToString())); - } + Databases = unconfiguredDatabases; return ValueTask.CompletedTask; } + var allDatabases = new List(); foreach (var database in databases) { - unconfiguredDatabases.Remove(database.Type); - if (!database.Enabled) - { - theories.Add((mi, _, _) => new TheoryDataRow() - .WithSkip($"Disabled") - .WithTestDisplayName(TestName(mi, database.Type)) - .WithTrait("Type", database.Type.ToString()) - .WithTrait("ConnectionString", database.ConnectionString)); - continue; - } - - - - // Build service provider for database - theories.Add((methodInfo, disposalTracker, databaseDataAttribute) => - { - var sp = BuildServiceProvider(databaseDataAttribute, database); - - return new ServiceTheoryDataRow(methodInfo, disposalTracker, sp) - .WithTestDisplayName(TestName(methodInfo, database.Type)) - .WithTrait("Type", database.Type.ToString()) - .WithTrait("ConnectionString", database.ConnectionString); - }); + unconfiguredDatabases.RemoveAll(db => db.Type == database.Type); + allDatabases.Add(database); } // Add entry for all still unconfigured database types - foreach (var unconfiguredDatabase in unconfiguredDatabases) - { - theories.Add((mi, _, _) => new TheoryDataRow() - .WithSkip("Not Configured") - .WithTestDisplayName(TestName(mi, unconfiguredDatabase)) - .WithTrait("Type", unconfiguredDatabase.ToString())); - } + allDatabases.AddRange(unconfiguredDatabases); - Builders = theories; + Databases = allDatabases; return ValueTask.CompletedTask; } @@ -113,73 +88,4 @@ public class DatabaseStartup : ITestPipelineStartup { return ValueTask.CompletedTask; } - - private static string TestName(MethodInfo methodInfo, SupportedDatabaseProviders database) - { - // Add containing type name to the beginning? - return $"{methodInfo.Name}({database})"; - } - - private IServiceProvider BuildServiceProvider(DatabaseDataAttribute databaseData, Database database) - { - var services = new ServiceCollection(); - services.AddLogging(builder => - { - builder.AddProvider(new XunitLoggerProvider(LogLevel.Information)); - }); - - services.AddDataProtection(); - - if (databaseData.UseFakeTimeProvider) - { - services.AddSingleton(); - } - - services.AddSingleton(database); - - if (database.Type == SupportedDatabaseProviders.SqlServer && !database.UseEf) - { - services.AddDapperRepositories(databaseData.SelfHosted); - var globalSettings = new GlobalSettings - { - DatabaseProvider = "sqlServer", - SqlServer = new GlobalSettings.SqlSettings - { - ConnectionString = database.ConnectionString, - }, - }; - services.AddSingleton(globalSettings); - services.AddSingleton(globalSettings); - services.AddDistributedSqlServerCache((options) => - { - options.ConnectionString = database.ConnectionString; - options.SchemaName = "dbo"; - options.TableName = "Cache"; - }); - - if (!string.IsNullOrEmpty(databaseData.MigrationName)) - { - services.AddSingleton( - sp => new SqlMigrationTesterService(database.ConnectionString, databaseData.MigrationName) - ); - } - } - else - { - services.SetupEntityFramework(database.ConnectionString, database.Type); - services.AddPasswordManagerEFRepositories(databaseData.SelfHosted); - services.AddSingleton(); - - if (!string.IsNullOrEmpty(databaseData.MigrationName)) - { - services.AddSingleton(sp => - { - var dbContext = sp.GetRequiredService(); - return new EfMigrationTesterService(dbContext, database.Type, databaseData.MigrationName); - }); - } - } - - return services.BuildServiceProvider(); - } } diff --git a/test/Infrastructure.IntegrationTest/Utilities/DbSiloAttribute.cs b/test/Infrastructure.IntegrationTest/Utilities/DbSiloAttribute.cs new file mode 100644 index 0000000000..ce97c9edb2 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/Utilities/DbSiloAttribute.cs @@ -0,0 +1,68 @@ +using Bit.Core.Enums; +using Microsoft.Data.SqlClient; +using Microsoft.Data.Sqlite; +using MySqlConnector; +using Npgsql; + +namespace Bit.Infrastructure.IntegrationTest.Utilities; + +public class DbSiloAttribute : TestCustomizerAttribute +{ + public string DatabaseName { get; } + + public DbSiloAttribute(string databaseName) + { + DatabaseName = databaseName; + } + + public override Task CustomizeAsync(CustomizationContext customizationContext) + { + var database = customizationContext.Database; + if (!database.Enabled || string.IsNullOrEmpty(database.ConnectionString)) + { + // Nothing to customize + return Task.CompletedTask; + } + + if (database.Type == SupportedDatabaseProviders.MySql) + { + var connectionStringBuilder = new MySqlConnectionStringBuilder(database.ConnectionString) + { + Database = DatabaseName + }; + + database.ConnectionString = connectionStringBuilder.ConnectionString; + } + else if(database.Type == SupportedDatabaseProviders.Postgres) + { + var connectionStringBuilder = new NpgsqlConnectionStringBuilder(database.ConnectionString) + { + Database = DatabaseName + }; + + database.ConnectionString = connectionStringBuilder.ConnectionString; + } + else if (database.Type == SupportedDatabaseProviders.Sqlite) + { + var connectionStringBuilder = new SqliteConnectionStringBuilder(database.ConnectionString); + + var existingFileInfo = new FileInfo(connectionStringBuilder.DataSource); + + // Should we require that the existing file actually exists? + + var newFileInfo = new FileInfo(Path.Join(existingFileInfo.DirectoryName, $"{DatabaseName}.{existingFileInfo.Extension}")); + + connectionStringBuilder.DataSource = newFileInfo.FullName; + database.ConnectionString = connectionStringBuilder.ConnectionString; + } + else + { + var connectionStringBuilder = new SqlConnectionStringBuilder(database.ConnectionString) + { + DataSource = DatabaseName + }; + database.ConnectionString = connectionStringBuilder.ConnectionString; + } + return Task.CompletedTask; + } +} diff --git a/test/Infrastructure.IntegrationTest/Utilities/DefaultCustomizerAttribute.cs b/test/Infrastructure.IntegrationTest/Utilities/DefaultCustomizerAttribute.cs new file mode 100644 index 0000000000..895fa65aa5 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/Utilities/DefaultCustomizerAttribute.cs @@ -0,0 +1,68 @@ + +using Bit.Core.Enums; +using Bit.Core.Settings; +using Bit.Infrastructure.Dapper; +using Bit.Infrastructure.EntityFramework; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using MySqlConnector; + +namespace Bit.Infrastructure.IntegrationTest.Utilities; + +/// +/// The default customization applied to all database tests. If no customizer is added, this is added implicitly. +/// +public class DefaultCustomizerAttribute : TestCustomizerAttribute +{ + public static readonly DefaultCustomizerAttribute Instance = new(); + + public override Task CustomizeAsync(CustomizationContext customizationContext) + { + var database = customizationContext.Database; + var services = customizationContext.Services; + if (!database.Enabled) + { + // Do nothing + return Task.CompletedTask; + } + + services.AddLogging(builder => + { + builder.AddProvider(new XunitLoggerProvider(LogLevel.Information)); + }); + + services.AddDataProtection(); + + services.AddSingleton(customizationContext.Database); + + if (database.Type == SupportedDatabaseProviders.SqlServer && !database.UseEf) + { + services.AddDapperRepositories(false); + var globalSettings = new GlobalSettings + { + DatabaseProvider = "sqlServer", + SqlServer = new GlobalSettings.SqlSettings + { + ConnectionString = database.ConnectionString, + }, + }; + services.AddSingleton(globalSettings); + services.AddSingleton(globalSettings); + services.AddDistributedSqlServerCache((options) => + { + options.ConnectionString = database.ConnectionString; + options.SchemaName = "dbo"; + options.TableName = "Cache"; + }); + } + else + { + services.SetupEntityFramework(database.ConnectionString, database.Type); + services.AddPasswordManagerEFRepositories(false); + services.AddSingleton(); + } + + return Task.CompletedTask; + } +} diff --git a/test/Infrastructure.IntegrationTest/Utilities/TestCustomizerAttribute.cs b/test/Infrastructure.IntegrationTest/Utilities/TestCustomizerAttribute.cs new file mode 100644 index 0000000000..acd086d109 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/Utilities/TestCustomizerAttribute.cs @@ -0,0 +1,40 @@ +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; +using Xunit.Sdk; + +namespace Bit.Infrastructure.IntegrationTest.Utilities; + +public class CustomizationContext +{ + // Defaults to Database.Enabled if left as null + public bool? Enabled { get; set; } + public Database Database { get; } + + public MethodInfo TestMethod { get; } + + public DisposalTracker DisposalTracker { get; } + + public IServiceCollection Services { get; } + + public Func ParameterResolver { get; set; } = DefaultParameterResolver; + + + public CustomizationContext(Database database, MethodInfo testMethod, DisposalTracker disposalTracker) + { + Database = database; + TestMethod = testMethod; + DisposalTracker = disposalTracker; + Services = new ServiceCollection(); + } + + private static object? DefaultParameterResolver(ServiceProvider services, ParameterInfo parameter) + { + return services.GetService(parameter.ParameterType); + } +} + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method | AttributeTargets.Assembly)] +public abstract class TestCustomizerAttribute : Attribute +{ + public abstract Task CustomizeAsync(CustomizationContext customizationContext); +}