diff --git a/src/Core/GlobalSettings.cs b/src/Core/GlobalSettings.cs index 1d7cd52bb8..9412e2ee99 100644 --- a/src/Core/GlobalSettings.cs +++ b/src/Core/GlobalSettings.cs @@ -15,7 +15,8 @@ namespace Bit.Core public virtual bool DisableUserRegistration { get; set; } public virtual InstallationSettings Installation { get; set; } = new InstallationSettings(); public virtual BaseServiceUriSettings BaseServiceUri { get; set; } = new BaseServiceUriSettings(); - public virtual SqlServerSettings SqlServer { get; set; } = new SqlServerSettings(); + public virtual SqlSettings SqlServer { get; set; } = new SqlSettings(); + public virtual SqlSettings PostgreSql { get; set; } = new SqlSettings(); public virtual MailSettings Mail { get; set; } = new MailSettings(); public virtual StorageSettings Storage { get; set; } = new StorageSettings(); public virtual StorageSettings Events { get; set; } = new StorageSettings(); @@ -45,7 +46,7 @@ namespace Bit.Core public string InternalVault { get; set; } } - public class SqlServerSettings + public class SqlSettings { private string _connectionString; private string _readOnlyConnectionString; diff --git a/src/Core/Repositories/PostgreSql/Repository.cs b/src/Core/Repositories/PostgreSql/Repository.cs index 789f3810f4..9515fc201d 100644 --- a/src/Core/Repositories/PostgreSql/Repository.cs +++ b/src/Core/Repositories/PostgreSql/Repository.cs @@ -12,7 +12,7 @@ namespace Bit.Core.Repositories.PostgreSql where TId : IEquatable where T : class, ITableObject { - public Repository(string connectionString, string readOnlyConnectionString, string table) + public Repository(string connectionString, string readOnlyConnectionString, string table = null) : base(connectionString, readOnlyConnectionString) { if(!string.IsNullOrWhiteSpace(table)) @@ -21,7 +21,7 @@ namespace Bit.Core.Repositories.PostgreSql } else { - Table = SnakeCase(typeof(T).Name); + Table = SnakeCase(typeof(T).Name).ToLowerInvariant(); } } diff --git a/src/Core/Repositories/PostgreSql/UserRepository.cs b/src/Core/Repositories/PostgreSql/UserRepository.cs new file mode 100644 index 0000000000..cc9cfc14ea --- /dev/null +++ b/src/Core/Repositories/PostgreSql/UserRepository.cs @@ -0,0 +1,159 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Threading.Tasks; +using Bit.Core.Models.Data; +using Bit.Core.Models.Table; +using Dapper; +using Npgsql; + +namespace Bit.Core.Repositories.PostgreSql +{ + public class UserRepository : Repository, IUserRepository + { + public UserRepository(GlobalSettings globalSettings) + : this(globalSettings.PostgreSql.ConnectionString, globalSettings.PostgreSql.ReadOnlyConnectionString) + { } + + public UserRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public override async Task GetByIdAsync(Guid id) + { + return await base.GetByIdAsync(id); + } + + public async Task GetByEmailAsync(string email) + { + using(var connection = new NpgsqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "user_read_by_email", + new { email = email }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task GetKdfInformationByEmailAsync(string email) + { + using(var connection = new NpgsqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "user_read_kdf_by_email", + new { email = email }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task> SearchAsync(string email, int skip, int take) + { + using(var connection = new NpgsqlConnection(ReadOnlyConnectionString)) + { + var results = await connection.QueryAsync( + "user_search", + new { email = email, skip = skip, take = take }, + commandType: CommandType.StoredProcedure, + commandTimeout: 120); + + return results.ToList(); + } + } + + public async Task> GetManyByPremiumAsync(bool premium) + { + using(var connection = new NpgsqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "user_read_by_premium", + new { premium = premium }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task> GetManyByPremiumRenewalAsync() + { + using(var connection = new NpgsqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "user_read_by_premium_renewal", + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetPublicKeyAsync(Guid id) + { + using(var connection = new NpgsqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "user_read_public_key_by_id", + new { id = id }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task GetAccountRevisionDateAsync(Guid id) + { + using(var connection = new NpgsqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "user_read_account_revision_date_by_id", + new { id = id }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public override async Task ReplaceAsync(User user) + { + await base.ReplaceAsync(user); + } + + public override async Task DeleteAsync(User user) + { + using(var connection = new NpgsqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + $"user_delete_by_id", + new { id = user.Id }, + commandType: CommandType.StoredProcedure, + commandTimeout: 180); + } + } + + public async Task UpdateStorageAsync(Guid id) + { + using(var connection = new NpgsqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "user_update_storage", + new { id = id }, + commandType: CommandType.StoredProcedure, + commandTimeout: 180); + } + } + + public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate) + { + using(var connection = new NpgsqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + "user_update_renewal_reminder_date", + new { id = id, renewal_reminder_date = renewalReminderDate }, + commandType: CommandType.StoredProcedure); + } + } + } +} diff --git a/src/Core/Utilities/ServiceCollectionExtensions.cs b/src/Core/Utilities/ServiceCollectionExtensions.cs index 3f8e044155..ab4906c6cc 100644 --- a/src/Core/Utilities/ServiceCollectionExtensions.cs +++ b/src/Core/Utilities/ServiceCollectionExtensions.cs @@ -19,6 +19,7 @@ using Microsoft.WindowsAzure.Storage; using System; using System.IO; using SqlServerRepos = Bit.Core.Repositories.SqlServer; +using PostgreSqlRepos = Bit.Core.Repositories.PostgreSql; using System.Threading.Tasks; using TableStorageRepos = Bit.Core.Repositories.TableStorage; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -32,19 +33,26 @@ namespace Bit.Core.Utilities { public static void AddSqlServerRepositories(this IServiceCollection services, GlobalSettings globalSettings) { - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); + if(!string.IsNullOrWhiteSpace(globalSettings.PostgreSql?.ConnectionString)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + } if(globalSettings.SelfHosted) {