diff --git a/src/Api/Controllers/AccountsController.cs b/src/Api/Controllers/AccountsController.cs index dac1a70b59..0ec437b114 100644 --- a/src/Api/Controllers/AccountsController.cs +++ b/src/Api/Controllers/AccountsController.cs @@ -33,6 +33,8 @@ namespace Bit.Api.Controllers private readonly IPaymentService _paymentService; private readonly IUserRepository _userRepository; private readonly IUserService _userService; + private readonly ISendRepository _sendRepository; + private readonly ISendService _sendService; public AccountsController( GlobalSettings globalSettings, @@ -43,7 +45,9 @@ namespace Bit.Api.Controllers IPaymentService paymentService, ISsoUserRepository ssoUserRepository, IUserRepository userRepository, - IUserService userService) + IUserService userService, + ISendRepository sendRepository, + ISendService sendService) { _cipherRepository = cipherRepository; _folderRepository = folderRepository; @@ -53,6 +57,8 @@ namespace Bit.Api.Controllers _paymentService = paymentService; _userRepository = userRepository; _userService = userService; + _sendRepository = sendRepository; + _sendService = sendService; } [HttpPost("prelogin")] @@ -301,13 +307,28 @@ namespace Bit.Api.Controllers } } + var sends = new List(); + if (model.Sends?.Count() > 0) + { + var existingSends = await _sendRepository.GetManyByUserIdAsync(user.Id); + var sendsDict = model.Sends?.ToDictionary(s => s.Id); + if (existingSends.Any() && sendsDict != null) + { + foreach (var send in existingSends.Where(s => sendsDict.ContainsKey(s.Id))) + { + sends.Add(sendsDict[send.Id].ToSend(send, _sendService)); + } + } + } + var result = await _userService.UpdateKeyAsync( user, model.MasterPasswordHash, model.Key, model.PrivateKey, ciphers, - folders); + folders, + sends); if (result.Succeeded) { diff --git a/src/Core/Models/Api/Request/Accounts/UpdateKeyRequestModel.cs b/src/Core/Models/Api/Request/Accounts/UpdateKeyRequestModel.cs index d895498a16..45468cfe60 100644 --- a/src/Core/Models/Api/Request/Accounts/UpdateKeyRequestModel.cs +++ b/src/Core/Models/Api/Request/Accounts/UpdateKeyRequestModel.cs @@ -12,6 +12,7 @@ namespace Bit.Core.Models.Api public IEnumerable Ciphers { get; set; } [Required] public IEnumerable Folders { get; set; } + public IEnumerable Sends { get; set; } [Required] public string PrivateKey { get; set; } [Required] diff --git a/src/Core/Models/Api/Request/SendRequestModel.cs b/src/Core/Models/Api/Request/SendRequestModel.cs index 85fe417b57..a1b2ecb69a 100644 --- a/src/Core/Models/Api/Request/SendRequestModel.cs +++ b/src/Core/Models/Api/Request/SendRequestModel.cs @@ -130,4 +130,10 @@ namespace Bit.Core.Models.Api return existingSend; } } + + public class SendWithIdRequestModel : SendRequestModel + { + [Required] + public Guid? Id { get; set; } + } } diff --git a/src/Core/Repositories/ICipherRepository.cs b/src/Core/Repositories/ICipherRepository.cs index aac2cae0dc..5a5e1ede76 100644 --- a/src/Core/Repositories/ICipherRepository.cs +++ b/src/Core/Repositories/ICipherRepository.cs @@ -28,7 +28,7 @@ namespace Bit.Core.Repositories Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId); Task DeleteByUserIdAsync(Guid userId); Task DeleteByOrganizationIdAsync(Guid organizationId); - Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders); + Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends); Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers); Task CreateAsync(IEnumerable ciphers, IEnumerable folders); Task CreateAsync(IEnumerable ciphers, IEnumerable collections, diff --git a/src/Core/Repositories/SqlServer/CipherRepository.cs b/src/Core/Repositories/SqlServer/CipherRepository.cs index 16c3adffc4..faf626d768 100644 --- a/src/Core/Repositories/SqlServer/CipherRepository.cs +++ b/src/Core/Repositories/SqlServer/CipherRepository.cs @@ -282,7 +282,7 @@ namespace Bit.Core.Repositories.SqlServer } } - public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders) + public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends) { using (var connection = new SqlConnection(ConnectionString)) { @@ -323,7 +323,11 @@ namespace Bit.Core.Repositories.SqlServer SELECT TOP 0 * INTO #TempFolder - FROM [dbo].[Folder]"; + FROM [dbo].[Folder] + + SELECT TOP 0 * + INTO #TempSend + FROM [dbo].[Send]"; using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) { @@ -352,6 +356,16 @@ namespace Bit.Core.Repositories.SqlServer } } + if (sends.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "#TempSend"; + var dataTable = BuildSendsTable(bulkCopy, sends); + bulkCopy.WriteToServer(dataTable); + } + } + // 4. Insert into real tables from temp tables and clean up. var sql = string.Empty; @@ -389,9 +403,26 @@ namespace Bit.Core.Repositories.SqlServer F.[UserId] = @UserId"; } + if (sends.Any()) + { + sql += @" + UPDATE + [dbo].[Send] + SET + [Key] = TS.[Key], + [RevisionDate] = TS.[RevisionDate] + FROM + [dbo].[Send] S + INNER JOIN + #TempSend TS ON S.Id = TS.Id + WHERE + S.[UserId] = @UserId"; + } + sql += @" DROP TABLE #TempCipher - DROP TABLE #TempFolder"; + DROP TABLE #TempFolder + DROP TABLE #TempSend"; using (var cmd = new SqlCommand(sql, connection, transaction)) { @@ -833,6 +864,82 @@ namespace Bit.Core.Repositories.SqlServer return collectionCiphersTable; } + private DataTable BuildSendsTable(SqlBulkCopy bulkCopy, IEnumerable sends) + { + var s = sends.FirstOrDefault(); + if (s == null) + { + throw new ApplicationException("Must have some ciphers to bulk import."); + } + + var sendsTable = new DataTable("SendsDataTable"); + + var idColumn = new DataColumn(nameof(s.Id), s.Id.GetType()); + sendsTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(s.UserId), typeof(Guid)); + sendsTable.Columns.Add(userIdColumn); + var organizationIdColumn = new DataColumn(nameof(s.OrganizationId), typeof(Guid)); + sendsTable.Columns.Add(organizationIdColumn); + var typeColumn = new DataColumn(nameof(s.Type), s.Type.GetType()); + sendsTable.Columns.Add(typeColumn); + var dataColumn = new DataColumn(nameof(s.Data), s.Data.GetType()); + sendsTable.Columns.Add(dataColumn); + var keyColumn = new DataColumn(nameof(s.Key), s.Key.GetType()); + sendsTable.Columns.Add(keyColumn); + var passwordColumn = new DataColumn(nameof(s.Password), typeof(string)); + sendsTable.Columns.Add(passwordColumn); + var maxAccessCountColumn = new DataColumn(nameof(s.MaxAccessCount), typeof(int)); + sendsTable.Columns.Add(maxAccessCountColumn); + var accessCountColumn = new DataColumn(nameof(s.AccessCount), s.AccessCount.GetType()); + sendsTable.Columns.Add(accessCountColumn); + var creationDateColumn = new DataColumn(nameof(s.CreationDate), s.CreationDate.GetType()); + sendsTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(s.RevisionDate), s.RevisionDate.GetType()); + sendsTable.Columns.Add(revisionDateColumn); + var expirationDateColumn = new DataColumn(nameof(s.ExpirationDate), typeof(DateTime)); + sendsTable.Columns.Add(expirationDateColumn); + var deletionDateColumn = new DataColumn(nameof(s.DeletionDate), s.DeletionDate.GetType()); + sendsTable.Columns.Add(deletionDateColumn); + var disabledColumn = new DataColumn(nameof(s.Disabled), s.Disabled.GetType()); + sendsTable.Columns.Add(disabledColumn); + var hideEmailColumn = new DataColumn(nameof(s.HideEmail), typeof(bool)); + sendsTable.Columns.Add(hideEmailColumn); + + foreach (DataColumn col in sendsTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + sendsTable.PrimaryKey = keys; + + foreach (var send in sends) + { + var row = sendsTable.NewRow(); + + row[idColumn] = send.Id; + row[userIdColumn] = send.UserId.HasValue ? (object)send.UserId.Value : DBNull.Value; + row[organizationIdColumn] = send.OrganizationId.HasValue ? (object)send.OrganizationId.Value : DBNull.Value; + row[typeColumn] = (short)send.Type; + row[dataColumn] = send.Data; + row[keyColumn] = send.Key; + row[passwordColumn] = send.Password; + row[maxAccessCountColumn] = send.MaxAccessCount.HasValue ? (object)send.MaxAccessCount : DBNull.Value; + row[accessCountColumn] = send.AccessCount; + row[creationDateColumn] = send.CreationDate; + row[revisionDateColumn] = send.RevisionDate; + row[expirationDateColumn] = send.ExpirationDate.HasValue ? (object)send.ExpirationDate : DBNull.Value; + row[deletionDateColumn] = send.DeletionDate; + row[disabledColumn] = send.Disabled; + row[hideEmailColumn] = send.HideEmail.HasValue ? (object)send.HideEmail : DBNull.Value; + + sendsTable.Rows.Add(row); + } + + return sendsTable; + } + public class CipherDetailsWithCollections : CipherDetails { public DataTable CollectionIds { get; set; } diff --git a/src/Core/Services/IUserService.cs b/src/Core/Services/IUserService.cs index 8fdafbb6fc..9cf5f7450d 100644 --- a/src/Core/Services/IUserService.cs +++ b/src/Core/Services/IUserService.cs @@ -38,7 +38,7 @@ namespace Bit.Core.Services Task ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, string key, KdfType kdf, int kdfIterations); Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, - IEnumerable ciphers, IEnumerable folders); + IEnumerable ciphers, IEnumerable folders, IEnumerable sends); Task RefreshSecurityStampAsync(User user, string masterPasswordHash); Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true); Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index 54deb5756d..0222715489 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -52,6 +52,7 @@ namespace Bit.Core.Services private readonly ICurrentContext _currentContext; private readonly GlobalSettings _globalSettings; private readonly IOrganizationService _organizationService; + private readonly ISendRepository _sendRepository; public UserService( IUserRepository userRepository, @@ -79,7 +80,8 @@ namespace Bit.Core.Services IFido2 fido2, ICurrentContext currentContext, GlobalSettings globalSettings, - IOrganizationService organizationService) + IOrganizationService organizationService, + ISendRepository sendRepository) : base( store, optionsAccessor, @@ -113,6 +115,7 @@ namespace Bit.Core.Services _currentContext = currentContext; _globalSettings = globalSettings; _organizationService = organizationService; + _sendRepository = sendRepository; } public Guid? GetProperUserId(ClaimsPrincipal principal) @@ -726,7 +729,7 @@ namespace Bit.Core.Services } public async Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, - IEnumerable ciphers, IEnumerable folders) + IEnumerable ciphers, IEnumerable folders, IEnumerable sends) { if (user == null) { @@ -739,9 +742,9 @@ namespace Bit.Core.Services user.SecurityStamp = Guid.NewGuid().ToString(); user.Key = key; user.PrivateKey = privateKey; - if (ciphers.Any() || folders.Any()) + if (ciphers.Any() || folders.Any() || sends.Any()) { - await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders); + await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders, sends); } else {