1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-09 23:28:12 -05:00

Prepare for send direct upload (#1174)

* Add sendId to path

Event Grid returns the blob path, which will be used to grab a Send and verify file size

* Re-validate access upon file download

Increment access count only when file is downloaded. File
name and size are leaked, but this is a good first step toward
solving the access-download race
This commit is contained in:
Matt Gibson 2021-03-01 15:01:04 -06:00 committed by GitHub
parent 13f12aaf58
commit 8d5fc21b51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 119 additions and 34 deletions

View File

@ -61,7 +61,8 @@ namespace Bit.Api.Controllers
} }
var sendResponse = new SendAccessResponseModel(send, _globalSettings); var sendResponse = new SendAccessResponseModel(send, _globalSettings);
if (send.UserId.HasValue) { if (send.UserId.HasValue)
{
var creator = await _userService.GetUserByIdAsync(send.UserId.Value); var creator = await _userService.GetUserByIdAsync(send.UserId.Value);
sendResponse.CreatorIdentifier = creator.Email; sendResponse.CreatorIdentifier = creator.Email;
} }
@ -69,14 +70,40 @@ namespace Bit.Api.Controllers
} }
[AllowAnonymous] [AllowAnonymous]
[HttpGet("access/file/{id}")] [HttpPost("{encodedSendId}/access/file/{fileId}")]
public async Task<SendFileDownloadDataResponseModel> GetSendFileDownloadData(string id) public async Task<IActionResult> GetSendFileDownloadData(string encodedSendId,
string fileId, [FromBody] SendAccessRequestModel model)
{ {
return new SendFileDownloadDataResponseModel() var sendId = new Guid(CoreHelpers.Base64UrlDecode(encodedSendId));
var send = await _sendRepository.GetByIdAsync(sendId);
if (send == null)
{ {
Id = id, throw new BadRequestException("Could not locate send");
Url = await _sendFileStorageService.GetSendFileDownloadUrlAsync(id), }
};
var (url, passwordRequired, passwordInvalid) = await _sendService.GetSendFileDownloadUrlAsync(send, fileId,
model.Password);
if (passwordRequired)
{
return new UnauthorizedResult();
}
if (passwordInvalid)
{
await Task.Delay(2000);
throw new BadRequestException("Invalid password.");
}
if (send == null)
{
throw new NotFoundException();
}
return new ObjectResult(new SendFileDownloadDataResponseModel()
{
Id = fileId,
Url = url,
});
} }
[HttpGet("{id}")] [HttpGet("{id}")]

View File

@ -13,5 +13,6 @@ namespace Bit.Core.Services
Task CreateSendAsync(Send send, SendFileData data, Stream stream, long requestLength); Task CreateSendAsync(Send send, SendFileData data, Stream stream, long requestLength);
Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password); Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password);
string HashPassword(string password); string HashPassword(string password);
Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password);
} }
} }

View File

@ -8,9 +8,9 @@ namespace Bit.Core.Services
public interface ISendFileStorageService public interface ISendFileStorageService
{ {
Task UploadNewFileAsync(Stream stream, Send send, string fileId); Task UploadNewFileAsync(Stream stream, Send send, string fileId);
Task DeleteFileAsync(string fileId); Task DeleteFileAsync(Send send, string fileId);
Task DeleteFilesForOrganizationAsync(Guid organizationId); Task DeleteFilesForOrganizationAsync(Guid organizationId);
Task DeleteFilesForUserAsync(Guid userId); Task DeleteFilesForUserAsync(Guid userId);
Task<string> GetSendFileDownloadUrlAsync(string fileId); Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId);
} }
} }

View File

@ -10,12 +10,14 @@ namespace Bit.Core.Services
{ {
public class AzureSendFileStorageService : ISendFileStorageService public class AzureSendFileStorageService : ISendFileStorageService
{ {
private const string FilesContainerName = "sendfiles"; public const string FilesContainerName = "sendfiles";
private static readonly TimeSpan _downloadLinkLiveTime = TimeSpan.FromMinutes(1); private static readonly TimeSpan _downloadLinkLiveTime = TimeSpan.FromMinutes(1);
private readonly CloudBlobClient _blobClient; private readonly CloudBlobClient _blobClient;
private CloudBlobContainer _sendFilesContainer; private CloudBlobContainer _sendFilesContainer;
public static string SendIdFromBlobName(string blobName) => blobName.Split('/')[0];
public static string BlobName(Send send, string fileId) => $"{send.Id}/{fileId}";
public AzureSendFileStorageService( public AzureSendFileStorageService(
GlobalSettings globalSettings) GlobalSettings globalSettings)
{ {
@ -26,7 +28,7 @@ namespace Bit.Core.Services
public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) public async Task UploadNewFileAsync(Stream stream, Send send, string fileId)
{ {
await InitAsync(); await InitAsync();
var blob = _sendFilesContainer.GetBlockBlobReference(fileId); var blob = _sendFilesContainer.GetBlockBlobReference(BlobName(send, fileId));
if (send.UserId.HasValue) if (send.UserId.HasValue)
{ {
blob.Metadata.Add("userId", send.UserId.Value.ToString()); blob.Metadata.Add("userId", send.UserId.Value.ToString());
@ -39,10 +41,10 @@ namespace Bit.Core.Services
await blob.UploadFromStreamAsync(stream); await blob.UploadFromStreamAsync(stream);
} }
public async Task DeleteFileAsync(string fileId) public async Task DeleteFileAsync(Send send, string fileId)
{ {
await InitAsync(); await InitAsync();
var blob = _sendFilesContainer.GetBlockBlobReference(fileId); var blob = _sendFilesContainer.GetBlockBlobReference(BlobName(send, fileId));
await blob.DeleteIfExistsAsync(); await blob.DeleteIfExistsAsync();
} }
@ -56,14 +58,14 @@ namespace Bit.Core.Services
await InitAsync(); await InitAsync();
} }
public async Task<string> GetSendFileDownloadUrlAsync(string fileId) public async Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId)
{ {
await InitAsync(); await InitAsync();
var blob = _sendFilesContainer.GetBlockBlobReference(fileId); var blob = _sendFilesContainer.GetBlockBlobReference(BlobName(send, fileId));
var accessPolicy = new SharedAccessBlobPolicy() var accessPolicy = new SharedAccessBlobPolicy()
{ {
SharedAccessExpiryTime = DateTime.UtcNow.Add(_downloadLinkLiveTime), SharedAccessExpiryTime = DateTime.UtcNow.Add(_downloadLinkLiveTime),
Permissions = SharedAccessBlobPermissions.Read Permissions = SharedAccessBlobPermissions.Read,
}; };
return blob.Uri + blob.GetSharedAccessSignature(accessPolicy); return blob.Uri + blob.GetSharedAccessSignature(accessPolicy);

View File

@ -3,6 +3,7 @@ using System.IO;
using System; using System;
using Bit.Core.Models.Table; using Bit.Core.Models.Table;
using Bit.Core.Settings; using Bit.Core.Settings;
using System.Linq;
namespace Bit.Core.Services namespace Bit.Core.Services
{ {
@ -11,6 +12,9 @@ namespace Bit.Core.Services
private readonly string _baseDirPath; private readonly string _baseDirPath;
private readonly string _baseSendUrl; private readonly string _baseSendUrl;
private string RelativeFilePath(Send send, string fileID) => $"{send.Id}/{fileID}";
private string FilePath(Send send, string fileID) => $"{_baseDirPath}/{RelativeFilePath(send, fileID)}";
public LocalSendStorageService( public LocalSendStorageService(
GlobalSettings globalSettings) GlobalSettings globalSettings)
{ {
@ -21,17 +25,21 @@ namespace Bit.Core.Services
public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) public async Task UploadNewFileAsync(Stream stream, Send send, string fileId)
{ {
await InitAsync(); await InitAsync();
using (var fs = File.Create($"{_baseDirPath}/{fileId}")) var path = FilePath(send, fileId);
Directory.CreateDirectory(Path.GetDirectoryName(path));
using (var fs = File.Create(path))
{ {
stream.Seek(0, SeekOrigin.Begin); stream.Seek(0, SeekOrigin.Begin);
await stream.CopyToAsync(fs); await stream.CopyToAsync(fs);
} }
} }
public async Task DeleteFileAsync(string fileId) public async Task DeleteFileAsync(Send send, string fileId)
{ {
await InitAsync(); await InitAsync();
DeleteFileIfExists($"{_baseDirPath}/{fileId}"); var path = FilePath(send, fileId);
DeleteFileIfExists(path);
DeleteDirectoryIfExistsAndEmpty(Path.GetDirectoryName(path));
} }
public async Task DeleteFilesForOrganizationAsync(Guid organizationId) public async Task DeleteFilesForOrganizationAsync(Guid organizationId)
@ -44,10 +52,10 @@ namespace Bit.Core.Services
await InitAsync(); await InitAsync();
} }
public async Task<string> GetSendFileDownloadUrlAsync(string fileId) public async Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId)
{ {
await InitAsync(); await InitAsync();
return $"{_baseSendUrl}/{fileId}"; return $"{_baseSendUrl}/{RelativeFilePath(send, fileId)}";
} }
private void DeleteFileIfExists(string path) private void DeleteFileIfExists(string path)
@ -58,6 +66,14 @@ namespace Bit.Core.Services
} }
} }
private void DeleteDirectoryIfExistsAndEmpty(string path)
{
if (Directory.Exists(path) && !Directory.EnumerateFiles(path).Any())
{
Directory.Delete(path);
}
}
private Task InitAsync() private Task InitAsync()
{ {
if (!Directory.Exists(_baseDirPath)) if (!Directory.Exists(_baseDirPath))

View File

@ -124,7 +124,6 @@ namespace Bit.Core.Services
} }
var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false);
await _sendFileStorageService.UploadNewFileAsync(stream, send, fileId);
try try
{ {
@ -133,11 +132,12 @@ namespace Bit.Core.Services
send.Data = JsonConvert.SerializeObject(data, send.Data = JsonConvert.SerializeObject(data,
new JsonSerializerSettings { NullValueHandling = NullValueHandling.Ignore }); new JsonSerializerSettings { NullValueHandling = NullValueHandling.Ignore });
await SaveSendAsync(send); await SaveSendAsync(send);
await _sendFileStorageService.UploadNewFileAsync(stream, send, fileId);
} }
catch catch
{ {
// Clean up since this is not transactional // Clean up since this is not transactional
await _sendFileStorageService.DeleteFileAsync(fileId); await _sendFileStorageService.DeleteFileAsync(send, fileId);
throw; throw;
} }
} }
@ -148,27 +148,26 @@ namespace Bit.Core.Services
if (send.Type == Enums.SendType.File) if (send.Type == Enums.SendType.File)
{ {
var data = JsonConvert.DeserializeObject<SendFileData>(send.Data); var data = JsonConvert.DeserializeObject<SendFileData>(send.Data);
await _sendFileStorageService.DeleteFileAsync(data.Id); await _sendFileStorageService.DeleteFileAsync(send, data.Id);
} }
await _pushService.PushSyncSendDeleteAsync(send); await _pushService.PushSyncSendDeleteAsync(send);
} }
// Response: Send, password required, password invalid public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send,
public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password) string password)
{ {
var send = await _sendRepository.GetByIdAsync(sendId);
var now = DateTime.UtcNow; var now = DateTime.UtcNow;
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount || if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled || send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
send.DeletionDate < now) send.DeletionDate < now)
{ {
return (null, false, false); return (false, false, false);
} }
if (!string.IsNullOrWhiteSpace(send.Password)) if (!string.IsNullOrWhiteSpace(send.Password))
{ {
if (string.IsNullOrWhiteSpace(password)) if (string.IsNullOrWhiteSpace(password))
{ {
return (null, true, false); return (false, true, false);
} }
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password); var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded) if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
@ -177,11 +176,51 @@ namespace Bit.Core.Services
} }
if (passwordResult == PasswordVerificationResult.Failed) if (passwordResult == PasswordVerificationResult.Failed)
{ {
return (null, false, true); return (false, false, true);
} }
} }
// TODO: maybe move this to a simple ++ sproc?
return (true, false, false);
}
// Response: Send, password required, password invalid
public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Can only get a download URL for a file type of Send");
}
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
send.AccessCount++; send.AccessCount++;
await _sendRepository.ReplaceAsync(send);
return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false);
}
// Response: Send, password required, password invalid
public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password)
{
var send = await _sendRepository.GetByIdAsync(sendId);
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
// TODO: maybe move this to a simple ++ sproc?
if (send.Type != SendType.File)
{
// File sends are incremented during file download
send.AccessCount++;
}
await _sendRepository.ReplaceAsync(send); await _sendRepository.ReplaceAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed); await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed);
return (send, false, false); return (send, false, false);

View File

@ -12,7 +12,7 @@ namespace Bit.Core.Services
return Task.FromResult(0); return Task.FromResult(0);
} }
public Task DeleteFileAsync(string fileId) public Task DeleteFileAsync(Send send, string fileId)
{ {
return Task.FromResult(0); return Task.FromResult(0);
} }
@ -27,7 +27,7 @@ namespace Bit.Core.Services
return Task.FromResult(0); return Task.FromResult(0);
} }
public Task<string> GetSendFileDownloadUrlAsync(string fileId) public Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId)
{ {
return Task.FromResult((string)null); return Task.FromResult((string)null);
} }