diff --git a/src/Core/Domains/User.cs b/src/Core/Domains/User.cs index a1d0147d82..6c5bf6b7ab 100644 --- a/src/Core/Domains/User.cs +++ b/src/Core/Domains/User.cs @@ -9,7 +9,7 @@ namespace Bit.Core.Domains internal static string TypeValue = "user"; [JsonProperty("id")] - public string Id { get; set; } = Guid.NewGuid().ToString(); + public string Id { get; set; } [JsonProperty("type")] public string Type { get; private set; } = TypeValue; diff --git a/src/Core/Repositories/SqlServer/BaseRepository.cs b/src/Core/Repositories/SqlServer/BaseRepository.cs new file mode 100644 index 0000000000..7ad8feba11 --- /dev/null +++ b/src/Core/Repositories/SqlServer/BaseRepository.cs @@ -0,0 +1,52 @@ +using System; + +namespace Bit.Core.Repositories.SqlServer +{ + public abstract class BaseRepository + { + private static readonly long _baseDateTicks = new DateTime(1900, 1, 1).Ticks; + + public BaseRepository(string connectionString) + { + if(string.IsNullOrWhiteSpace(connectionString)) + { + throw new ArgumentNullException(nameof(connectionString)); + } + + ConnectionString = connectionString; + } + + protected string ConnectionString { get; private set; } + + /// + /// Generate sequential Guid for Sql Server. + /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs + /// + /// A comb Guid. + protected Guid GenerateComb() + { + var guidArray = Guid.NewGuid().ToByteArray(); + + var now = DateTime.UtcNow; + + // Get the days and milliseconds which will be used to build the byte string + var days = new TimeSpan(now.Ticks - _baseDateTicks); + var msecs = now.TimeOfDay; + + // Convert to a byte array + // Note that SQL Server is accurate to 1/300th of a millisecond so we divide by 3.333333 + var daysArray = BitConverter.GetBytes(days.Days); + var msecsArray = BitConverter.GetBytes((long)(msecs.TotalMilliseconds / 3.333333)); + + // Reverse the bytes to match SQL Servers ordering + Array.Reverse(daysArray); + Array.Reverse(msecsArray); + + // Copy the bytes into the guid + Array.Copy(daysArray, daysArray.Length - 2, guidArray, guidArray.Length - 6, 2); + Array.Copy(msecsArray, msecsArray.Length - 4, guidArray, guidArray.Length - 4, 4); + + return new Guid(guidArray); + } + } +} diff --git a/src/Core/Repositories/SqlServer/CipherRepository.cs b/src/Core/Repositories/SqlServer/CipherRepository.cs index 516106feb0..42fc92166d 100644 --- a/src/Core/Repositories/SqlServer/CipherRepository.cs +++ b/src/Core/Repositories/SqlServer/CipherRepository.cs @@ -1,27 +1,193 @@ using System; +using System.Linq; using System.Collections.Generic; +using System.Data.SqlClient; using System.Threading.Tasks; +using Bit.Core.Repositories.SqlServer.Models; +using DataTableProxy; +using Bit.Core.Domains; namespace Bit.Core.Repositories.SqlServer { - public class CipherRepository : ICipherRepository + public class CipherRepository : BaseRepository, ICipherRepository { public CipherRepository(string connectionString) + : base(connectionString) { } public Task DirtyCiphersAsync(string userId) { - throw new NotImplementedException(); + return Task.FromResult(0); } public Task UpdateDirtyCiphersAsync(IEnumerable ciphers) { - throw new NotImplementedException(); + var cleanedCiphers = ciphers.Where(c => c is Cipher); + if(cleanedCiphers.Count() == 0) + { + return Task.FromResult(0); + } + + using(var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using(var transaction = connection.BeginTransaction()) + { + try + { + // 1. Create temp tables to bulk copy into. + + var sqlCreateTemp = @" + SELECT TOP 0 * + INTO #TempFolder + FROM [dbo].[Folder] + + SELECT TOP 0 * + INTO #TempSite + FROM [dbo].[Site]"; + + using(var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) + { + cmd.ExecuteNonQuery(); + } + + // 2. Bulk bopy into temp tables. + + using(var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "#TempFolder"; + + var dataTable = cleanedCiphers + .Where(c => c is Folder) + .Select(c => new FolderTableModel(c as Folder)) + .ToTable(new ClassMapping().AddAllPropertiesAsColumns()); + + bulkCopy.WriteToServer(dataTable); + } + + using(var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "#TempSite"; + + var dataTable = cleanedCiphers + .Where(c => c is Site) + .Select(c => new SiteTableModel(c as Site)) + .ToTable(new ClassMapping().AddAllPropertiesAsColumns()); + + bulkCopy.WriteToServer(dataTable); + } + + // 3. Insert into real tables from temp tables and clean up. + + var sqlUpdate = @" + UPDATE + [dbo].[Folder] + SET + [UserId] = TF.[UserId], + [Name] = TF.[Name], + [CreationDate] = TF.[CreationDate], + [RevisionDate] = TF.[RevisionDate] + FROM + [dbo].[Folder] F + INNER JOIN + #TempFolder TF ON F.Id = TF.Id + + UPDATE + [dbo].[Site] + SET + [UserId] = TS.[UserId], + [FolderId] = TS.[FolderId], + [Name] = TS.[Name], + [Uri] = TS.[Uri], + [Username] = TS.[Username], + [Password] = TS.[Password], + [Notes] = TS.[Notes], + [CreationDate] = TS.[CreationDate], + [RevisionDate] = TS.[RevisionDate] + FROM + [dbo].[Site] S + INNER JOIN + #TempSite TS ON S.Id = TS.Id + + DROP TABLE #TempFolder + DROP TABLE #TempSite"; + + using(var cmd = new SqlCommand(sqlUpdate, connection, transaction)) + { + cmd.ExecuteNonQuery(); + } + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + + return Task.FromResult(0); } public Task CreateAsync(IEnumerable ciphers) { - throw new NotImplementedException(); + var cleanedCiphers = ciphers.Where(c => c is Cipher); + if(cleanedCiphers.Count() == 0) + { + return Task.FromResult(0); + } + + // Generate new Ids for these new ciphers + foreach(var cipher in cleanedCiphers) + { + cipher.Id = GenerateComb().ToString(); + } + + using(var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using(var transaction = connection.BeginTransaction()) + { + try + { + using(var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[Folder]"; + + var dataTable = cleanedCiphers + .Where(c => c is Folder) + .Select(c => new FolderTableModel(c as Folder)) + .ToTable(new ClassMapping().AddAllPropertiesAsColumns()); + + bulkCopy.WriteToServer(dataTable); + } + + using(var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[Site]"; + + var dataTable = cleanedCiphers + .Where(c => c is Site) + .Select(c => new SiteTableModel(c as Site)) + .ToTable(new ClassMapping().AddAllPropertiesAsColumns()); + + bulkCopy.WriteToServer(dataTable); + } + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + + return Task.FromResult(0); } } } diff --git a/src/Core/Repositories/SqlServer/Repository.cs b/src/Core/Repositories/SqlServer/Repository.cs index 47cb4a574c..d08ec78462 100644 --- a/src/Core/Repositories/SqlServer/Repository.cs +++ b/src/Core/Repositories/SqlServer/Repository.cs @@ -8,19 +8,11 @@ using Dapper; namespace Bit.Core.Repositories.SqlServer { - public abstract class Repository : IRepository where T : IDataObject where TModel : ITableModel + public abstract class Repository : BaseRepository, IRepository where T : IDataObject where TModel : ITableModel { - private static readonly long _baseDateTicks = new DateTime(1900, 1, 1).Ticks; - public Repository(string connectionString, string schema = null, string table = null) + : base(connectionString) { - if(string.IsNullOrWhiteSpace(connectionString)) - { - throw new ArgumentNullException(nameof(connectionString)); - } - - ConnectionString = connectionString; - if(!string.IsNullOrWhiteSpace(table)) { Table = table; @@ -32,7 +24,6 @@ namespace Bit.Core.Repositories.SqlServer } } - protected string ConnectionString { get; private set; } protected string Schema { get; private set; } = "dbo"; protected string Table { get; private set; } = typeof(T).Name; @@ -109,36 +100,5 @@ namespace Bit.Core.Repositories.SqlServer commandType: CommandType.StoredProcedure); } } - - /// - /// Generate sequential Guid for Sql Server. - /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs - /// - /// A comb Guid. - protected Guid GenerateComb() - { - var guidArray = Guid.NewGuid().ToByteArray(); - - var now = DateTime.UtcNow; - - // Get the days and milliseconds which will be used to build the byte string - var days = new TimeSpan(now.Ticks - _baseDateTicks); - var msecs = now.TimeOfDay; - - // Convert to a byte array - // Note that SQL Server is accurate to 1/300th of a millisecond so we divide by 3.333333 - var daysArray = BitConverter.GetBytes(days.Days); - var msecsArray = BitConverter.GetBytes((long)(msecs.TotalMilliseconds / 3.333333)); - - // Reverse the bytes to match SQL Servers ordering - Array.Reverse(daysArray); - Array.Reverse(msecsArray); - - // Copy the bytes into the guid - Array.Copy(daysArray, daysArray.Length - 2, guidArray, guidArray.Length - 6, 2); - Array.Copy(msecsArray, msecsArray.Length - 4, guidArray, guidArray.Length - 4, 4); - - return new Guid(guidArray); - } } } diff --git a/src/Core/project.json b/src/Core/project.json index 76ffe0d6df..fc3d946a0d 100644 --- a/src/Core/project.json +++ b/src/Core/project.json @@ -15,7 +15,8 @@ "Microsoft.AspNet.DataProtection.Extensions": "1.0.0-rc1-final", "Microsoft.Azure.DocumentDB": "1.5.2", "Newtonsoft.Json": "8.0.1", - "Dapper": "1.42.0" + "Dapper": "1.42.0", + "DataTableProxy": "1.2.0" }, "frameworks": {