1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-01 16:12:49 -05:00

Revert filescoped (#2227)

* Revert "Add git blame entry (#2226)"

This reverts commit 239286737d.

* Revert "Turn on file scoped namespaces (#2225)"

This reverts commit 34fb4cca2a.
This commit is contained in:
Justin Baur
2022-08-29 15:53:48 -04:00
committed by GitHub
parent 239286737d
commit bae03feffe
1208 changed files with 74317 additions and 73126 deletions

View File

@ -7,136 +7,137 @@ using Bit.Core.Utilities;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class AmazonSesMailDeliveryService : IMailDeliveryService, IDisposable
namespace Bit.Core.Services
{
private readonly GlobalSettings _globalSettings;
private readonly IWebHostEnvironment _hostingEnvironment;
private readonly ILogger<AmazonSesMailDeliveryService> _logger;
private readonly IAmazonSimpleEmailService _client;
private readonly string _source;
private readonly string _senderTag;
private readonly string _configSetName;
public AmazonSesMailDeliveryService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<AmazonSesMailDeliveryService> logger)
: this(globalSettings, hostingEnvironment, logger,
new AmazonSimpleEmailServiceClient(
globalSettings.Amazon.AccessKeyId,
globalSettings.Amazon.AccessKeySecret,
RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region))
)
public class AmazonSesMailDeliveryService : IMailDeliveryService, IDisposable
{
}
private readonly GlobalSettings _globalSettings;
private readonly IWebHostEnvironment _hostingEnvironment;
private readonly ILogger<AmazonSesMailDeliveryService> _logger;
private readonly IAmazonSimpleEmailService _client;
private readonly string _source;
private readonly string _senderTag;
private readonly string _configSetName;
public AmazonSesMailDeliveryService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<AmazonSesMailDeliveryService> logger,
IAmazonSimpleEmailService amazonSimpleEmailService)
{
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId))
public AmazonSesMailDeliveryService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<AmazonSesMailDeliveryService> logger)
: this(globalSettings, hostingEnvironment, logger,
new AmazonSimpleEmailServiceClient(
globalSettings.Amazon.AccessKeyId,
globalSettings.Amazon.AccessKeySecret,
RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region))
)
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId));
}
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret))
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret));
}
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region))
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.Region));
}
var replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail);
_globalSettings = globalSettings;
_hostingEnvironment = hostingEnvironment;
_logger = logger;
_client = amazonSimpleEmailService;
_source = $"\"{globalSettings.SiteName}\" <{replyToEmail}>";
_senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}";
if (!string.IsNullOrWhiteSpace(_globalSettings.Mail.AmazonConfigSetName))
public AmazonSesMailDeliveryService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<AmazonSesMailDeliveryService> logger,
IAmazonSimpleEmailService amazonSimpleEmailService)
{
_configSetName = _globalSettings.Mail.AmazonConfigSetName;
}
}
public void Dispose()
{
_client?.Dispose();
}
public async Task SendEmailAsync(MailMessage message)
{
var request = new SendEmailRequest
{
ConfigurationSetName = _configSetName,
Source = _source,
Destination = new Destination
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId))
{
ToAddresses = message.ToEmails
.Select(email => CoreHelpers.PunyEncode(email))
.ToList()
},
Message = new Message
{
Subject = new Content(message.Subject),
Body = new Body
{
Html = new Content
{
Charset = "UTF-8",
Data = message.HtmlContent
},
Text = new Content
{
Charset = "UTF-8",
Data = message.TextContent
}
}
},
Tags = new List<MessageTag>
{
new MessageTag { Name = "Environment", Value = _hostingEnvironment.EnvironmentName },
new MessageTag { Name = "Sender", Value = _senderTag }
throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId));
}
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret))
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret));
}
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region))
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.Region));
}
};
if (message.BccEmails?.Any() ?? false)
{
request.Destination.BccAddresses = message.BccEmails
.Select(email => CoreHelpers.PunyEncode(email))
.ToList();
var replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail);
_globalSettings = globalSettings;
_hostingEnvironment = hostingEnvironment;
_logger = logger;
_client = amazonSimpleEmailService;
_source = $"\"{globalSettings.SiteName}\" <{replyToEmail}>";
_senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}";
if (!string.IsNullOrWhiteSpace(_globalSettings.Mail.AmazonConfigSetName))
{
_configSetName = _globalSettings.Mail.AmazonConfigSetName;
}
}
if (!string.IsNullOrWhiteSpace(message.Category))
public void Dispose()
{
request.Tags.Add(new MessageTag { Name = "Category", Value = message.Category });
_client?.Dispose();
}
try
public async Task SendEmailAsync(MailMessage message)
{
await SendAsync(request, false);
}
catch (Exception e)
{
_logger.LogWarning(e, "Failed to send email. Retrying...");
await SendAsync(request, true);
throw;
}
}
var request = new SendEmailRequest
{
ConfigurationSetName = _configSetName,
Source = _source,
Destination = new Destination
{
ToAddresses = message.ToEmails
.Select(email => CoreHelpers.PunyEncode(email))
.ToList()
},
Message = new Message
{
Subject = new Content(message.Subject),
Body = new Body
{
Html = new Content
{
Charset = "UTF-8",
Data = message.HtmlContent
},
Text = new Content
{
Charset = "UTF-8",
Data = message.TextContent
}
}
},
Tags = new List<MessageTag>
{
new MessageTag { Name = "Environment", Value = _hostingEnvironment.EnvironmentName },
new MessageTag { Name = "Sender", Value = _senderTag }
}
};
private async Task SendAsync(SendEmailRequest request, bool retry)
{
if (retry)
{
// wait and try again
await Task.Delay(2000);
if (message.BccEmails?.Any() ?? false)
{
request.Destination.BccAddresses = message.BccEmails
.Select(email => CoreHelpers.PunyEncode(email))
.ToList();
}
if (!string.IsNullOrWhiteSpace(message.Category))
{
request.Tags.Add(new MessageTag { Name = "Category", Value = message.Category });
}
try
{
await SendAsync(request, false);
}
catch (Exception e)
{
_logger.LogWarning(e, "Failed to send email. Retrying...");
await SendAsync(request, true);
throw;
}
}
private async Task SendAsync(SendEmailRequest request, bool retry)
{
if (retry)
{
// wait and try again
await Task.Delay(2000);
}
await _client.SendEmailAsync(request);
}
await _client.SendEmailAsync(request);
}
}

View File

@ -2,80 +2,81 @@
using Amazon.SQS;
using Bit.Core.Settings;
namespace Bit.Core.Services;
public class AmazonSqsBlockIpService : IBlockIpService, IDisposable
namespace Bit.Core.Services
{
private readonly IAmazonSQS _client;
private string _blockIpQueueUrl;
private string _unblockIpQueueUrl;
private bool _didInit = false;
private Tuple<string, bool, DateTime> _lastBlock;
public AmazonSqsBlockIpService(
GlobalSettings globalSettings)
: this(globalSettings, new AmazonSQSClient(
globalSettings.Amazon.AccessKeyId,
globalSettings.Amazon.AccessKeySecret,
RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region))
)
public class AmazonSqsBlockIpService : IBlockIpService, IDisposable
{
}
private readonly IAmazonSQS _client;
private string _blockIpQueueUrl;
private string _unblockIpQueueUrl;
private bool _didInit = false;
private Tuple<string, bool, DateTime> _lastBlock;
public AmazonSqsBlockIpService(
GlobalSettings globalSettings,
IAmazonSQS amazonSqs)
{
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId))
public AmazonSqsBlockIpService(
GlobalSettings globalSettings)
: this(globalSettings, new AmazonSQSClient(
globalSettings.Amazon.AccessKeyId,
globalSettings.Amazon.AccessKeySecret,
RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region))
)
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId));
}
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret))
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret));
}
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region))
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.Region));
}
_client = amazonSqs;
}
public void Dispose()
{
_client?.Dispose();
}
public async Task BlockIpAsync(string ipAddress, bool permanentBlock)
{
var now = DateTime.UtcNow;
if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock &&
(now - _lastBlock.Item3) < TimeSpan.FromMinutes(1))
public AmazonSqsBlockIpService(
GlobalSettings globalSettings,
IAmazonSQS amazonSqs)
{
// Already blocked this IP recently.
return;
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId))
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId));
}
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret))
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret));
}
if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region))
{
throw new ArgumentNullException(nameof(globalSettings.Amazon.Region));
}
_client = amazonSqs;
}
_lastBlock = new Tuple<string, bool, DateTime>(ipAddress, permanentBlock, now);
await _client.SendMessageAsync(_blockIpQueueUrl, ipAddress);
if (!permanentBlock)
public void Dispose()
{
await _client.SendMessageAsync(_unblockIpQueueUrl, ipAddress);
}
}
private async Task InitAsync()
{
if (_didInit)
{
return;
_client?.Dispose();
}
var blockIpQueue = await _client.GetQueueUrlAsync("block-ip");
_blockIpQueueUrl = blockIpQueue.QueueUrl;
var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip");
_unblockIpQueueUrl = unblockIpQueue.QueueUrl;
_didInit = true;
public async Task BlockIpAsync(string ipAddress, bool permanentBlock)
{
var now = DateTime.UtcNow;
if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock &&
(now - _lastBlock.Item3) < TimeSpan.FromMinutes(1))
{
// Already blocked this IP recently.
return;
}
_lastBlock = new Tuple<string, bool, DateTime>(ipAddress, permanentBlock, now);
await _client.SendMessageAsync(_blockIpQueueUrl, ipAddress);
if (!permanentBlock)
{
await _client.SendMessageAsync(_unblockIpQueueUrl, ipAddress);
}
}
private async Task InitAsync()
{
if (_didInit)
{
return;
}
var blockIpQueue = await _client.GetQueueUrlAsync("block-ip");
_blockIpQueueUrl = blockIpQueue.QueueUrl;
var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip");
_unblockIpQueueUrl = unblockIpQueue.QueueUrl;
_didInit = true;
}
}
}

View File

@ -7,126 +7,127 @@ using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class AppleIapService : IAppleIapService
namespace Bit.Core.Services
{
private readonly HttpClient _httpClient = new HttpClient();
private readonly GlobalSettings _globalSettings;
private readonly IWebHostEnvironment _hostingEnvironment;
private readonly IMetaDataRepository _metaDataRespository;
private readonly ILogger<AppleIapService> _logger;
public AppleIapService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
IMetaDataRepository metaDataRespository,
ILogger<AppleIapService> logger)
public class AppleIapService : IAppleIapService
{
_globalSettings = globalSettings;
_hostingEnvironment = hostingEnvironment;
_metaDataRespository = metaDataRespository;
_logger = logger;
}
private readonly HttpClient _httpClient = new HttpClient();
public async Task<AppleReceiptStatus> GetVerifiedReceiptStatusAsync(string receiptData)
{
var receiptStatus = await GetReceiptStatusAsync(receiptData);
if (receiptStatus?.Status != 0)
{
return null;
}
var validEnvironment = _globalSettings.AppleIap.AppInReview ||
(!(_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment == "Sandbox") ||
((_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment != "Sandbox");
var validProductBundle = receiptStatus.Receipt.BundleId == "com.bitwarden.desktop" ||
receiptStatus.Receipt.BundleId == "com.8bit.bitwarden";
var validProduct = receiptStatus.LatestReceiptInfo.LastOrDefault()?.ProductId == "premium_annually";
var validIds = receiptStatus.GetOriginalTransactionId() != null &&
receiptStatus.GetLastTransactionId() != null;
var validTransaction = receiptStatus.GetLastExpiresDate()
.GetValueOrDefault(DateTime.MinValue) > DateTime.UtcNow;
if (validEnvironment && validProductBundle && validProduct && validIds && validTransaction)
{
return receiptStatus;
}
return null;
}
private readonly GlobalSettings _globalSettings;
private readonly IWebHostEnvironment _hostingEnvironment;
private readonly IMetaDataRepository _metaDataRespository;
private readonly ILogger<AppleIapService> _logger;
public async Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId)
{
var originalTransactionId = receiptStatus.GetOriginalTransactionId();
if (string.IsNullOrWhiteSpace(originalTransactionId))
public AppleIapService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
IMetaDataRepository metaDataRespository,
ILogger<AppleIapService> logger)
{
throw new Exception("OriginalTransactionId is null");
_globalSettings = globalSettings;
_hostingEnvironment = hostingEnvironment;
_metaDataRespository = metaDataRespository;
_logger = logger;
}
await _metaDataRespository.UpsertAsync("AppleReceipt", originalTransactionId,
new Dictionary<string, string>
public async Task<AppleReceiptStatus> GetVerifiedReceiptStatusAsync(string receiptData)
{
var receiptStatus = await GetReceiptStatusAsync(receiptData);
if (receiptStatus?.Status != 0)
{
["Data"] = receiptStatus.GetReceiptData(),
["UserId"] = userId.ToString()
});
}
public async Task<Tuple<string, Guid?>> GetReceiptAsync(string originalTransactionId)
{
var receipt = await _metaDataRespository.GetAsync("AppleReceipt", originalTransactionId);
if (receipt == null)
{
return null;
}
return new Tuple<string, Guid?>(receipt.ContainsKey("Data") ? receipt["Data"] : null,
receipt.ContainsKey("UserId") ? new Guid(receipt["UserId"]) : (Guid?)null);
}
// Internal for testing
internal async Task<AppleReceiptStatus> GetReceiptStatusAsync(string receiptData, bool prod = true,
int attempt = 0, AppleReceiptStatus lastReceiptStatus = null)
{
try
{
if (attempt > 4)
{
throw new Exception(
$"Failed verifying Apple IAP after too many attempts. Last attempt status: {lastReceiptStatus?.Status.ToString() ?? "null"}");
return null;
}
var url = string.Format("https://{0}.itunes.apple.com/verifyReceipt", prod ? "buy" : "sandbox");
var response = await _httpClient.PostAsJsonAsync(url, new AppleVerifyReceiptRequestModel
var validEnvironment = _globalSettings.AppleIap.AppInReview ||
(!(_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment == "Sandbox") ||
((_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment != "Sandbox");
var validProductBundle = receiptStatus.Receipt.BundleId == "com.bitwarden.desktop" ||
receiptStatus.Receipt.BundleId == "com.8bit.bitwarden";
var validProduct = receiptStatus.LatestReceiptInfo.LastOrDefault()?.ProductId == "premium_annually";
var validIds = receiptStatus.GetOriginalTransactionId() != null &&
receiptStatus.GetLastTransactionId() != null;
var validTransaction = receiptStatus.GetLastExpiresDate()
.GetValueOrDefault(DateTime.MinValue) > DateTime.UtcNow;
if (validEnvironment && validProductBundle && validProduct && validIds && validTransaction)
{
ReceiptData = receiptData,
Password = _globalSettings.AppleIap.Password
});
if (response.IsSuccessStatusCode)
{
var receiptStatus = await response.Content.ReadFromJsonAsync<AppleReceiptStatus>();
if (receiptStatus.Status == 21007)
{
return await GetReceiptStatusAsync(receiptData, false, attempt + 1, receiptStatus);
}
else if (receiptStatus.Status == 21005)
{
await Task.Delay(2000);
return await GetReceiptStatusAsync(receiptData, prod, attempt + 1, receiptStatus);
}
return receiptStatus;
}
return null;
}
catch (Exception e)
public async Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId)
{
_logger.LogWarning(e, "Error verifying Apple IAP receipt.");
var originalTransactionId = receiptStatus.GetOriginalTransactionId();
if (string.IsNullOrWhiteSpace(originalTransactionId))
{
throw new Exception("OriginalTransactionId is null");
}
await _metaDataRespository.UpsertAsync("AppleReceipt", originalTransactionId,
new Dictionary<string, string>
{
["Data"] = receiptStatus.GetReceiptData(),
["UserId"] = userId.ToString()
});
}
return null;
public async Task<Tuple<string, Guid?>> GetReceiptAsync(string originalTransactionId)
{
var receipt = await _metaDataRespository.GetAsync("AppleReceipt", originalTransactionId);
if (receipt == null)
{
return null;
}
return new Tuple<string, Guid?>(receipt.ContainsKey("Data") ? receipt["Data"] : null,
receipt.ContainsKey("UserId") ? new Guid(receipt["UserId"]) : (Guid?)null);
}
// Internal for testing
internal async Task<AppleReceiptStatus> GetReceiptStatusAsync(string receiptData, bool prod = true,
int attempt = 0, AppleReceiptStatus lastReceiptStatus = null)
{
try
{
if (attempt > 4)
{
throw new Exception(
$"Failed verifying Apple IAP after too many attempts. Last attempt status: {lastReceiptStatus?.Status.ToString() ?? "null"}");
}
var url = string.Format("https://{0}.itunes.apple.com/verifyReceipt", prod ? "buy" : "sandbox");
var response = await _httpClient.PostAsJsonAsync(url, new AppleVerifyReceiptRequestModel
{
ReceiptData = receiptData,
Password = _globalSettings.AppleIap.Password
});
if (response.IsSuccessStatusCode)
{
var receiptStatus = await response.Content.ReadFromJsonAsync<AppleReceiptStatus>();
if (receiptStatus.Status == 21007)
{
return await GetReceiptStatusAsync(receiptData, false, attempt + 1, receiptStatus);
}
else if (receiptStatus.Status == 21005)
{
await Task.Delay(2000);
return await GetReceiptStatusAsync(receiptData, prod, attempt + 1, receiptStatus);
}
return receiptStatus;
}
}
catch (Exception e)
{
_logger.LogWarning(e, "Error verifying Apple IAP receipt.");
}
return null;
}
}
public class AppleVerifyReceiptRequestModel
{
[JsonPropertyName("receipt-data")]
public string ReceiptData { get; set; }
[JsonPropertyName("password")]
public string Password { get; set; }
}
}
public class AppleVerifyReceiptRequestModel
{
[JsonPropertyName("receipt-data")]
public string ReceiptData { get; set; }
[JsonPropertyName("password")]
public string Password { get; set; }
}

View File

@ -7,259 +7,260 @@ using Bit.Core.Models.Data;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class AzureAttachmentStorageService : IAttachmentStorageService
namespace Bit.Core.Services
{
public FileUploadType FileUploadType => FileUploadType.Azure;
public const string EventGridEnabledContainerName = "attachments-v2";
private const string _defaultContainerName = "attachments";
private readonly static string[] _attachmentContainerName = { "attachments", "attachments-v2" };
private static readonly TimeSpan blobLinkLiveTime = TimeSpan.FromMinutes(1);
private readonly BlobServiceClient _blobServiceClient;
private readonly Dictionary<string, BlobContainerClient> _attachmentContainers = new Dictionary<string, BlobContainerClient>();
private readonly ILogger<AzureAttachmentStorageService> _logger;
private string BlobName(Guid cipherId, CipherAttachment.MetaData attachmentData, Guid? organizationId = null, bool temp = false) =>
string.Concat(
temp ? "temp/" : "",
$"{cipherId}/",
organizationId != null ? $"{organizationId.Value}/" : "",
attachmentData.AttachmentId
);
public static (string cipherId, string organizationId, string attachmentId) IdentifiersFromBlobName(string blobName)
public class AzureAttachmentStorageService : IAttachmentStorageService
{
var parts = blobName.Split('/');
switch (parts.Length)
public FileUploadType FileUploadType => FileUploadType.Azure;
public const string EventGridEnabledContainerName = "attachments-v2";
private const string _defaultContainerName = "attachments";
private readonly static string[] _attachmentContainerName = { "attachments", "attachments-v2" };
private static readonly TimeSpan blobLinkLiveTime = TimeSpan.FromMinutes(1);
private readonly BlobServiceClient _blobServiceClient;
private readonly Dictionary<string, BlobContainerClient> _attachmentContainers = new Dictionary<string, BlobContainerClient>();
private readonly ILogger<AzureAttachmentStorageService> _logger;
private string BlobName(Guid cipherId, CipherAttachment.MetaData attachmentData, Guid? organizationId = null, bool temp = false) =>
string.Concat(
temp ? "temp/" : "",
$"{cipherId}/",
organizationId != null ? $"{organizationId.Value}/" : "",
attachmentData.AttachmentId
);
public static (string cipherId, string organizationId, string attachmentId) IdentifiersFromBlobName(string blobName)
{
case 4:
return (parts[1], parts[2], parts[3]);
case 3:
if (parts[0] == "temp")
{
return (parts[1], null, parts[2]);
}
else
{
return (parts[0], parts[1], parts[2]);
}
case 2:
return (parts[0], null, parts[1]);
default:
throw new Exception("Cannot determine cipher information from blob name");
}
}
public AzureAttachmentStorageService(
GlobalSettings globalSettings,
ILogger<AzureAttachmentStorageService> logger)
{
_blobServiceClient = new BlobServiceClient(globalSettings.Attachment.ConnectionString);
_logger = logger;
}
public async Task<string> GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData)
{
await InitAsync(attachmentData.ContainerName);
var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData));
var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(blobLinkLiveTime));
return sasUri.ToString();
}
public async Task<string> GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData)
{
await InitAsync(EventGridEnabledContainerName);
var blobClient = _attachmentContainers[EventGridEnabledContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData));
attachmentData.ContainerName = EventGridEnabledContainerName;
var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(blobLinkLiveTime));
return sasUri.ToString();
}
public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData)
{
attachmentData.ContainerName = _defaultContainerName;
await InitAsync(_defaultContainerName);
var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData));
var metadata = new Dictionary<string, string>();
metadata.Add("cipherId", cipher.Id.ToString());
if (cipher.UserId.HasValue)
{
metadata.Add("userId", cipher.UserId.Value.ToString());
}
else
{
metadata.Add("organizationId", cipher.OrganizationId.Value.ToString());
var parts = blobName.Split('/');
switch (parts.Length)
{
case 4:
return (parts[1], parts[2], parts[3]);
case 3:
if (parts[0] == "temp")
{
return (parts[1], null, parts[2]);
}
else
{
return (parts[0], parts[1], parts[2]);
}
case 2:
return (parts[0], null, parts[1]);
default:
throw new Exception("Cannot determine cipher information from blob name");
}
}
var headers = new BlobHttpHeaders
public AzureAttachmentStorageService(
GlobalSettings globalSettings,
ILogger<AzureAttachmentStorageService> logger)
{
ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\""
};
await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers });
}
public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData)
{
attachmentData.ContainerName = _defaultContainerName;
await InitAsync(_defaultContainerName);
var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient(
BlobName(cipherId, attachmentData, organizationId, temp: true));
var metadata = new Dictionary<string, string>();
metadata.Add("cipherId", cipherId.ToString());
metadata.Add("organizationId", organizationId.ToString());
var headers = new BlobHttpHeaders
{
ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\""
};
await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers });
}
public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData data)
{
await InitAsync(data.ContainerName);
var source = _attachmentContainers[data.ContainerName].GetBlobClient(
BlobName(cipherId, data, organizationId, temp: true));
if (!await source.ExistsAsync())
{
return;
_blobServiceClient = new BlobServiceClient(globalSettings.Attachment.ConnectionString);
_logger = logger;
}
await InitAsync(_defaultContainerName);
var dest = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipherId, data));
if (!await dest.ExistsAsync())
public async Task<string> GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData)
{
return;
await InitAsync(attachmentData.ContainerName);
var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData));
var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(blobLinkLiveTime));
return sasUri.ToString();
}
var original = _attachmentContainers[_defaultContainerName].GetBlobClient(
BlobName(cipherId, data, temp: true));
await original.DeleteIfExistsAsync();
await original.StartCopyFromUriAsync(dest.Uri);
await dest.DeleteIfExistsAsync();
await dest.StartCopyFromUriAsync(source.Uri);
}
public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer)
{
await InitAsync(attachmentData.ContainerName);
var source = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(
BlobName(cipherId, attachmentData, organizationId, temp: true));
await source.DeleteIfExistsAsync();
await InitAsync(originalContainer);
var original = _attachmentContainers[originalContainer].GetBlobClient(
BlobName(cipherId, attachmentData, temp: true));
if (!await original.ExistsAsync())
public async Task<string> GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData)
{
return;
await InitAsync(EventGridEnabledContainerName);
var blobClient = _attachmentContainers[EventGridEnabledContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData));
attachmentData.ContainerName = EventGridEnabledContainerName;
var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(blobLinkLiveTime));
return sasUri.ToString();
}
var dest = _attachmentContainers[originalContainer].GetBlobClient(
BlobName(cipherId, attachmentData));
await dest.DeleteIfExistsAsync();
await dest.StartCopyFromUriAsync(original.Uri);
await original.DeleteIfExistsAsync();
}
public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData)
{
await InitAsync(attachmentData.ContainerName);
var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(
BlobName(cipherId, attachmentData));
await blobClient.DeleteIfExistsAsync();
}
public async Task CleanupAsync(Guid cipherId) => await DeleteAttachmentsForPathAsync($"temp/{cipherId}");
public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) =>
await DeleteAttachmentsForPathAsync(cipherId.ToString());
public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId)
{
await InitAsync(_defaultContainerName);
}
public async Task DeleteAttachmentsForUserAsync(Guid userId)
{
await InitAsync(_defaultContainerName);
}
public async Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway)
{
await InitAsync(attachmentData.ContainerName);
var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData));
try
public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData)
{
var blobProperties = await blobClient.GetPropertiesAsync();
attachmentData.ContainerName = _defaultContainerName;
await InitAsync(_defaultContainerName);
var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData));
var metadata = blobProperties.Value.Metadata;
metadata["cipherId"] = cipher.Id.ToString();
var metadata = new Dictionary<string, string>();
metadata.Add("cipherId", cipher.Id.ToString());
if (cipher.UserId.HasValue)
{
metadata["userId"] = cipher.UserId.Value.ToString();
metadata.Add("userId", cipher.UserId.Value.ToString());
}
else
{
metadata["organizationId"] = cipher.OrganizationId.Value.ToString();
metadata.Add("organizationId", cipher.OrganizationId.Value.ToString());
}
await blobClient.SetMetadataAsync(metadata);
var headers = new BlobHttpHeaders
{
ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\""
};
await blobClient.SetHttpHeadersAsync(headers);
await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers });
}
var length = blobProperties.Value.ContentLength;
if (length < attachmentData.Size - leeway || length > attachmentData.Size + leeway)
public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData)
{
attachmentData.ContainerName = _defaultContainerName;
await InitAsync(_defaultContainerName);
var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient(
BlobName(cipherId, attachmentData, organizationId, temp: true));
var metadata = new Dictionary<string, string>();
metadata.Add("cipherId", cipherId.ToString());
metadata.Add("organizationId", organizationId.ToString());
var headers = new BlobHttpHeaders
{
return (false, length);
ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\""
};
await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers });
}
public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData data)
{
await InitAsync(data.ContainerName);
var source = _attachmentContainers[data.ContainerName].GetBlobClient(
BlobName(cipherId, data, organizationId, temp: true));
if (!await source.ExistsAsync())
{
return;
}
return (true, length);
}
catch (Exception ex)
{
_logger.LogError(ex, "Unhandled error in ValidateFileAsync");
return (false, null);
}
}
private async Task DeleteAttachmentsForPathAsync(string path)
{
foreach (var container in _attachmentContainerName)
{
await InitAsync(container);
var blobContainerClient = _attachmentContainers[container];
var blobItems = blobContainerClient.GetBlobsAsync(BlobTraits.None, BlobStates.None, prefix: path);
await foreach (var blobItem in blobItems)
await InitAsync(_defaultContainerName);
var dest = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipherId, data));
if (!await dest.ExistsAsync())
{
BlobClient blobClient = blobContainerClient.GetBlobClient(blobItem.Name);
await blobClient.DeleteIfExistsAsync();
return;
}
var original = _attachmentContainers[_defaultContainerName].GetBlobClient(
BlobName(cipherId, data, temp: true));
await original.DeleteIfExistsAsync();
await original.StartCopyFromUriAsync(dest.Uri);
await dest.DeleteIfExistsAsync();
await dest.StartCopyFromUriAsync(source.Uri);
}
public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer)
{
await InitAsync(attachmentData.ContainerName);
var source = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(
BlobName(cipherId, attachmentData, organizationId, temp: true));
await source.DeleteIfExistsAsync();
await InitAsync(originalContainer);
var original = _attachmentContainers[originalContainer].GetBlobClient(
BlobName(cipherId, attachmentData, temp: true));
if (!await original.ExistsAsync())
{
return;
}
var dest = _attachmentContainers[originalContainer].GetBlobClient(
BlobName(cipherId, attachmentData));
await dest.DeleteIfExistsAsync();
await dest.StartCopyFromUriAsync(original.Uri);
await original.DeleteIfExistsAsync();
}
public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData)
{
await InitAsync(attachmentData.ContainerName);
var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(
BlobName(cipherId, attachmentData));
await blobClient.DeleteIfExistsAsync();
}
public async Task CleanupAsync(Guid cipherId) => await DeleteAttachmentsForPathAsync($"temp/{cipherId}");
public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) =>
await DeleteAttachmentsForPathAsync(cipherId.ToString());
public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId)
{
await InitAsync(_defaultContainerName);
}
public async Task DeleteAttachmentsForUserAsync(Guid userId)
{
await InitAsync(_defaultContainerName);
}
public async Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway)
{
await InitAsync(attachmentData.ContainerName);
var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData));
try
{
var blobProperties = await blobClient.GetPropertiesAsync();
var metadata = blobProperties.Value.Metadata;
metadata["cipherId"] = cipher.Id.ToString();
if (cipher.UserId.HasValue)
{
metadata["userId"] = cipher.UserId.Value.ToString();
}
else
{
metadata["organizationId"] = cipher.OrganizationId.Value.ToString();
}
await blobClient.SetMetadataAsync(metadata);
var headers = new BlobHttpHeaders
{
ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\""
};
await blobClient.SetHttpHeadersAsync(headers);
var length = blobProperties.Value.ContentLength;
if (length < attachmentData.Size - leeway || length > attachmentData.Size + leeway)
{
return (false, length);
}
return (true, length);
}
catch (Exception ex)
{
_logger.LogError(ex, "Unhandled error in ValidateFileAsync");
return (false, null);
}
}
}
private async Task InitAsync(string containerName)
{
if (!_attachmentContainers.ContainsKey(containerName) || _attachmentContainers[containerName] == null)
private async Task DeleteAttachmentsForPathAsync(string path)
{
_attachmentContainers[containerName] = _blobServiceClient.GetBlobContainerClient(containerName);
if (containerName == "attachments")
foreach (var container in _attachmentContainerName)
{
await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.Blob, null, null);
await InitAsync(container);
var blobContainerClient = _attachmentContainers[container];
var blobItems = blobContainerClient.GetBlobsAsync(BlobTraits.None, BlobStates.None, prefix: path);
await foreach (var blobItem in blobItems)
{
BlobClient blobClient = blobContainerClient.GetBlobClient(blobItem.Name);
await blobClient.DeleteIfExistsAsync();
}
}
else
}
private async Task InitAsync(string containerName)
{
if (!_attachmentContainers.ContainsKey(containerName) || _attachmentContainers[containerName] == null)
{
await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.None, null, null);
_attachmentContainers[containerName] = _blobServiceClient.GetBlobContainerClient(containerName);
if (containerName == "attachments")
{
await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.Blob, null, null);
}
else
{
await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.None, null, null);
}
}
}
}

View File

@ -1,36 +1,37 @@
using Azure.Storage.Queues;
using Bit.Core.Settings;
namespace Bit.Core.Services;
public class AzureQueueBlockIpService : IBlockIpService
namespace Bit.Core.Services
{
private readonly QueueClient _blockIpQueueClient;
private readonly QueueClient _unblockIpQueueClient;
private Tuple<string, bool, DateTime> _lastBlock;
public AzureQueueBlockIpService(
GlobalSettings globalSettings)
public class AzureQueueBlockIpService : IBlockIpService
{
_blockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "blockip");
_unblockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "unblockip");
}
private readonly QueueClient _blockIpQueueClient;
private readonly QueueClient _unblockIpQueueClient;
private Tuple<string, bool, DateTime> _lastBlock;
public async Task BlockIpAsync(string ipAddress, bool permanentBlock)
{
var now = DateTime.UtcNow;
if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock &&
(now - _lastBlock.Item3) < TimeSpan.FromMinutes(1))
public AzureQueueBlockIpService(
GlobalSettings globalSettings)
{
// Already blocked this IP recently.
return;
_blockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "blockip");
_unblockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "unblockip");
}
_lastBlock = new Tuple<string, bool, DateTime>(ipAddress, permanentBlock, now);
await _blockIpQueueClient.SendMessageAsync(ipAddress);
if (!permanentBlock)
public async Task BlockIpAsync(string ipAddress, bool permanentBlock)
{
await _unblockIpQueueClient.SendMessageAsync(ipAddress, new TimeSpan(0, 15, 0));
var now = DateTime.UtcNow;
if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock &&
(now - _lastBlock.Item3) < TimeSpan.FromMinutes(1))
{
// Already blocked this IP recently.
return;
}
_lastBlock = new Tuple<string, bool, DateTime>(ipAddress, permanentBlock, now);
await _blockIpQueueClient.SendMessageAsync(ipAddress);
if (!permanentBlock)
{
await _unblockIpQueueClient.SendMessageAsync(ipAddress, new TimeSpan(0, 15, 0));
}
}
}
}

View File

@ -3,14 +3,15 @@ using Bit.Core.Models.Data;
using Bit.Core.Settings;
using Bit.Core.Utilities;
namespace Bit.Core.Services;
public class AzureQueueEventWriteService : AzureQueueService<IEvent>, IEventWriteService
namespace Bit.Core.Services
{
public AzureQueueEventWriteService(GlobalSettings globalSettings) : base(
new QueueClient(globalSettings.Events.ConnectionString, "event"),
JsonHelpers.IgnoreWritingNull)
{ }
public class AzureQueueEventWriteService : AzureQueueService<IEvent>, IEventWriteService
{
public AzureQueueEventWriteService(GlobalSettings globalSettings) : base(
new QueueClient(globalSettings.Events.ConnectionString, "event"),
JsonHelpers.IgnoreWritingNull)
{ }
public Task CreateAsync(IEvent e) => CreateManyAsync(new[] { e });
public Task CreateAsync(IEvent e) => CreateManyAsync(new[] { e });
}
}

View File

@ -3,18 +3,19 @@ using Bit.Core.Models.Mail;
using Bit.Core.Settings;
using Bit.Core.Utilities;
namespace Bit.Core.Services;
public class AzureQueueMailService : AzureQueueService<IMailQueueMessage>, IMailEnqueuingService
namespace Bit.Core.Services
{
public AzureQueueMailService(GlobalSettings globalSettings) : base(
new QueueClient(globalSettings.Mail.ConnectionString, "mail"),
JsonHelpers.IgnoreWritingNull)
{ }
public class AzureQueueMailService : AzureQueueService<IMailQueueMessage>, IMailEnqueuingService
{
public AzureQueueMailService(GlobalSettings globalSettings) : base(
new QueueClient(globalSettings.Mail.ConnectionString, "mail"),
JsonHelpers.IgnoreWritingNull)
{ }
public Task EnqueueAsync(IMailQueueMessage message, Func<IMailQueueMessage, Task> fallback) =>
CreateManyAsync(new[] { message });
public Task EnqueueAsync(IMailQueueMessage message, Func<IMailQueueMessage, Task> fallback) =>
CreateManyAsync(new[] { message });
public Task EnqueueManyAsync(IEnumerable<IMailQueueMessage> messages, Func<IMailQueueMessage, Task> fallback) =>
CreateManyAsync(messages);
public Task EnqueueManyAsync(IEnumerable<IMailQueueMessage> messages, Func<IMailQueueMessage, Task> fallback) =>
CreateManyAsync(messages);
}
}

View File

@ -8,189 +8,190 @@ using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Http;
namespace Bit.Core.Services;
public class AzureQueuePushNotificationService : IPushNotificationService
namespace Bit.Core.Services
{
private readonly QueueClient _queueClient;
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
public AzureQueuePushNotificationService(
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor)
public class AzureQueuePushNotificationService : IPushNotificationService
{
_queueClient = new QueueClient(globalSettings.Notifications.ConnectionString, "notifications");
_globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor;
}
private readonly QueueClient _queueClient;
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds);
}
public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds);
}
public async Task PushSyncCipherDeleteAsync(Cipher cipher)
{
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
{
if (cipher.OrganizationId.HasValue)
public AzureQueuePushNotificationService(
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor)
{
var message = new SyncCipherPushNotification
_queueClient = new QueueClient(globalSettings.Notifications.ConnectionString, "notifications");
_globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor;
}
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds);
}
public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds);
}
public async Task PushSyncCipherDeleteAsync(Cipher cipher)
{
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
{
if (cipher.OrganizationId.HasValue)
{
Id = cipher.Id,
OrganizationId = cipher.OrganizationId,
RevisionDate = cipher.RevisionDate,
CollectionIds = collectionIds,
var message = new SyncCipherPushNotification
{
Id = cipher.Id,
OrganizationId = cipher.OrganizationId,
RevisionDate = cipher.RevisionDate,
CollectionIds = collectionIds,
};
await SendMessageAsync(type, message, true);
}
else if (cipher.UserId.HasValue)
{
var message = new SyncCipherPushNotification
{
Id = cipher.Id,
UserId = cipher.UserId,
RevisionDate = cipher.RevisionDate,
};
await SendMessageAsync(type, message, true);
}
}
public async Task PushSyncFolderCreateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderCreate);
}
public async Task PushSyncFolderUpdateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderUpdate);
}
public async Task PushSyncFolderDeleteAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderDelete);
}
private async Task PushFolderAsync(Folder folder, PushType type)
{
var message = new SyncFolderPushNotification
{
Id = folder.Id,
UserId = folder.UserId,
RevisionDate = folder.RevisionDate
};
await SendMessageAsync(type, message, true);
}
else if (cipher.UserId.HasValue)
public async Task PushSyncCiphersAsync(Guid userId)
{
var message = new SyncCipherPushNotification
await PushUserAsync(userId, PushType.SyncCiphers);
}
public async Task PushSyncVaultAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncVault);
}
public async Task PushSyncOrgKeysAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncOrgKeys);
}
public async Task PushSyncSettingsAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncSettings);
}
public async Task PushLogOutAsync(Guid userId)
{
await PushUserAsync(userId, PushType.LogOut);
}
private async Task PushUserAsync(Guid userId, PushType type)
{
var message = new UserPushNotification
{
Id = cipher.Id,
UserId = cipher.UserId,
RevisionDate = cipher.RevisionDate,
UserId = userId,
Date = DateTime.UtcNow
};
await SendMessageAsync(type, message, true);
await SendMessageAsync(type, message, false);
}
}
public async Task PushSyncFolderCreateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderCreate);
}
public async Task PushSyncFolderUpdateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderUpdate);
}
public async Task PushSyncFolderDeleteAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderDelete);
}
private async Task PushFolderAsync(Folder folder, PushType type)
{
var message = new SyncFolderPushNotification
public async Task PushSyncSendCreateAsync(Send send)
{
Id = folder.Id,
UserId = folder.UserId,
RevisionDate = folder.RevisionDate
};
await PushSendAsync(send, PushType.SyncSendCreate);
}
await SendMessageAsync(type, message, true);
}
public async Task PushSyncCiphersAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncCiphers);
}
public async Task PushSyncVaultAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncVault);
}
public async Task PushSyncOrgKeysAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncOrgKeys);
}
public async Task PushSyncSettingsAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncSettings);
}
public async Task PushLogOutAsync(Guid userId)
{
await PushUserAsync(userId, PushType.LogOut);
}
private async Task PushUserAsync(Guid userId, PushType type)
{
var message = new UserPushNotification
public async Task PushSyncSendUpdateAsync(Send send)
{
UserId = userId,
Date = DateTime.UtcNow
};
await PushSendAsync(send, PushType.SyncSendUpdate);
}
await SendMessageAsync(type, message, false);
}
public async Task PushSyncSendCreateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendCreate);
}
public async Task PushSyncSendUpdateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendUpdate);
}
public async Task PushSyncSendDeleteAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendDelete);
}
private async Task PushSendAsync(Send send, PushType type)
{
if (send.UserId.HasValue)
public async Task PushSyncSendDeleteAsync(Send send)
{
var message = new SyncSendPushNotification
await PushSendAsync(send, PushType.SyncSendDelete);
}
private async Task PushSendAsync(Send send, PushType type)
{
if (send.UserId.HasValue)
{
Id = send.Id,
UserId = send.UserId.Value,
RevisionDate = send.RevisionDate
};
var message = new SyncSendPushNotification
{
Id = send.Id,
UserId = send.UserId.Value,
RevisionDate = send.RevisionDate
};
await SendMessageAsync(type, message, true);
await SendMessageAsync(type, message, true);
}
}
}
private async Task SendMessageAsync<T>(PushType type, T payload, bool excludeCurrentContext)
{
var contextId = GetContextIdentifier(excludeCurrentContext);
var message = JsonSerializer.Serialize(new PushNotificationData<T>(type, payload, contextId),
JsonHelpers.IgnoreWritingNull);
await _queueClient.SendMessageAsync(message);
}
private string GetContextIdentifier(bool excludeCurrentContext)
{
if (!excludeCurrentContext)
private async Task SendMessageAsync<T>(PushType type, T payload, bool excludeCurrentContext)
{
return null;
var contextId = GetContextIdentifier(excludeCurrentContext);
var message = JsonSerializer.Serialize(new PushNotificationData<T>(type, payload, contextId),
JsonHelpers.IgnoreWritingNull);
await _queueClient.SendMessageAsync(message);
}
var currentContext = _httpContextAccessor?.HttpContext?.
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier;
}
private string GetContextIdentifier(bool excludeCurrentContext)
{
if (!excludeCurrentContext)
{
return null;
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
// Noop
return Task.FromResult(0);
}
var currentContext = _httpContextAccessor?.HttpContext?.
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier;
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
// Noop
return Task.FromResult(0);
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
// Noop
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
// Noop
return Task.FromResult(0);
}
}
}

View File

@ -5,44 +5,45 @@ using Bit.Core.Models.Business;
using Bit.Core.Settings;
using Bit.Core.Utilities;
namespace Bit.Core.Services;
public class AzureQueueReferenceEventService : IReferenceEventService
namespace Bit.Core.Services
{
private const string _queueName = "reference-events";
private readonly QueueClient _queueClient;
private readonly GlobalSettings _globalSettings;
public AzureQueueReferenceEventService(
GlobalSettings globalSettings)
public class AzureQueueReferenceEventService : IReferenceEventService
{
_queueClient = new QueueClient(globalSettings.Events.ConnectionString, _queueName);
_globalSettings = globalSettings;
}
private const string _queueName = "reference-events";
public async Task RaiseEventAsync(ReferenceEvent referenceEvent)
{
await SendMessageAsync(referenceEvent);
}
private readonly QueueClient _queueClient;
private readonly GlobalSettings _globalSettings;
private async Task SendMessageAsync(ReferenceEvent referenceEvent)
{
if (_globalSettings.SelfHosted)
public AzureQueueReferenceEventService(
GlobalSettings globalSettings)
{
// Ignore for self-hosted
return;
_queueClient = new QueueClient(globalSettings.Events.ConnectionString, _queueName);
_globalSettings = globalSettings;
}
try
public async Task RaiseEventAsync(ReferenceEvent referenceEvent)
{
var message = JsonSerializer.Serialize(referenceEvent, JsonHelpers.IgnoreWritingNullAndCamelCase);
// Messages need to be base64 encoded
var encodedMessage = Convert.ToBase64String(Encoding.UTF8.GetBytes(message));
await _queueClient.SendMessageAsync(encodedMessage);
await SendMessageAsync(referenceEvent);
}
catch
private async Task SendMessageAsync(ReferenceEvent referenceEvent)
{
// Ignore failure
if (_globalSettings.SelfHosted)
{
// Ignore for self-hosted
return;
}
try
{
var message = JsonSerializer.Serialize(referenceEvent, JsonHelpers.IgnoreWritingNullAndCamelCase);
// Messages need to be base64 encoded
var encodedMessage = Convert.ToBase64String(Encoding.UTF8.GetBytes(message));
await _queueClient.SendMessageAsync(encodedMessage);
}
catch
{
// Ignore failure
}
}
}
}

View File

@ -3,75 +3,76 @@ using System.Text.Json;
using Azure.Storage.Queues;
using Bit.Core.Utilities;
namespace Bit.Core.Services;
public abstract class AzureQueueService<T>
namespace Bit.Core.Services
{
protected QueueClient _queueClient;
protected JsonSerializerOptions _jsonOptions;
protected AzureQueueService(QueueClient queueClient, JsonSerializerOptions jsonOptions)
public abstract class AzureQueueService<T>
{
_queueClient = queueClient;
_jsonOptions = jsonOptions;
}
protected QueueClient _queueClient;
protected JsonSerializerOptions _jsonOptions;
public async Task CreateManyAsync(IEnumerable<T> messages)
{
if (messages?.Any() != true)
protected AzureQueueService(QueueClient queueClient, JsonSerializerOptions jsonOptions)
{
return;
_queueClient = queueClient;
_jsonOptions = jsonOptions;
}
foreach (var json in SerializeMany(messages, _jsonOptions))
public async Task CreateManyAsync(IEnumerable<T> messages)
{
await _queueClient.SendMessageAsync(json);
}
}
protected IEnumerable<string> SerializeMany(IEnumerable<T> messages, JsonSerializerOptions jsonOptions)
{
// Calculate Base-64 encoded text with padding
int getBase64Size(int byteCount) => ((4 * byteCount / 3) + 3) & ~3;
var messagesList = new List<string>();
var messagesListSize = 0;
int calculateByteSize(int totalSize, int toAdd) =>
// Calculate the total length this would be w/ "[]" and commas
getBase64Size(totalSize + toAdd + messagesList.Count + 2);
// Format the final array string, i.e. [{...},{...}]
string getArrayString()
{
if (messagesList.Count == 1)
if (messages?.Any() != true)
{
return CoreHelpers.Base64EncodeString(messagesList[0]);
return;
}
foreach (var json in SerializeMany(messages, _jsonOptions))
{
await _queueClient.SendMessageAsync(json);
}
return CoreHelpers.Base64EncodeString(
string.Concat("[", string.Join(',', messagesList), "]"));
}
var serializedMessages = messages.Select(message =>
JsonSerializer.Serialize(message, jsonOptions));
foreach (var message in serializedMessages)
protected IEnumerable<string> SerializeMany(IEnumerable<T> messages, JsonSerializerOptions jsonOptions)
{
var messageSize = Encoding.UTF8.GetByteCount(message);
if (calculateByteSize(messagesListSize, messageSize) > _queueClient.MessageMaxBytes)
// Calculate Base-64 encoded text with padding
int getBase64Size(int byteCount) => ((4 * byteCount / 3) + 3) & ~3;
var messagesList = new List<string>();
var messagesListSize = 0;
int calculateByteSize(int totalSize, int toAdd) =>
// Calculate the total length this would be w/ "[]" and commas
getBase64Size(totalSize + toAdd + messagesList.Count + 2);
// Format the final array string, i.e. [{...},{...}]
string getArrayString()
{
if (messagesList.Count == 1)
{
return CoreHelpers.Base64EncodeString(messagesList[0]);
}
return CoreHelpers.Base64EncodeString(
string.Concat("[", string.Join(',', messagesList), "]"));
}
var serializedMessages = messages.Select(message =>
JsonSerializer.Serialize(message, jsonOptions));
foreach (var message in serializedMessages)
{
var messageSize = Encoding.UTF8.GetByteCount(message);
if (calculateByteSize(messagesListSize, messageSize) > _queueClient.MessageMaxBytes)
{
yield return getArrayString();
messagesListSize = 0;
messagesList.Clear();
}
messagesList.Add(message);
messagesListSize += messageSize;
}
if (messagesList.Any())
{
yield return getArrayString();
messagesListSize = 0;
messagesList.Clear();
}
messagesList.Add(message);
messagesListSize += messageSize;
}
if (messagesList.Any())
{
yield return getArrayString();
}
}
}

View File

@ -6,136 +6,137 @@ using Bit.Core.Enums;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class AzureSendFileStorageService : ISendFileStorageService
namespace Bit.Core.Services
{
public const string FilesContainerName = "sendfiles";
private static readonly TimeSpan _downloadLinkLiveTime = TimeSpan.FromMinutes(1);
private readonly BlobServiceClient _blobServiceClient;
private readonly ILogger<AzureSendFileStorageService> _logger;
private BlobContainerClient _sendFilesContainerClient;
public FileUploadType FileUploadType => FileUploadType.Azure;
public static string SendIdFromBlobName(string blobName) => blobName.Split('/')[0];
public static string BlobName(Send send, string fileId) => $"{send.Id}/{fileId}";
public AzureSendFileStorageService(
GlobalSettings globalSettings,
ILogger<AzureSendFileStorageService> logger)
public class AzureSendFileStorageService : ISendFileStorageService
{
_blobServiceClient = new BlobServiceClient(globalSettings.Send.ConnectionString);
_logger = logger;
}
public const string FilesContainerName = "sendfiles";
private static readonly TimeSpan _downloadLinkLiveTime = TimeSpan.FromMinutes(1);
private readonly BlobServiceClient _blobServiceClient;
private readonly ILogger<AzureSendFileStorageService> _logger;
private BlobContainerClient _sendFilesContainerClient;
public async Task UploadNewFileAsync(Stream stream, Send send, string fileId)
{
await InitAsync();
public FileUploadType FileUploadType => FileUploadType.Azure;
var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId));
public static string SendIdFromBlobName(string blobName) => blobName.Split('/')[0];
public static string BlobName(Send send, string fileId) => $"{send.Id}/{fileId}";
var metadata = new Dictionary<string, string>();
if (send.UserId.HasValue)
public AzureSendFileStorageService(
GlobalSettings globalSettings,
ILogger<AzureSendFileStorageService> logger)
{
metadata.Add("userId", send.UserId.Value.ToString());
}
else
{
metadata.Add("organizationId", send.OrganizationId.Value.ToString());
_blobServiceClient = new BlobServiceClient(globalSettings.Send.ConnectionString);
_logger = logger;
}
var headers = new BlobHttpHeaders
public async Task UploadNewFileAsync(Stream stream, Send send, string fileId)
{
ContentDisposition = $"attachment; filename=\"{fileId}\""
};
await InitAsync();
await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers });
}
public async Task DeleteFileAsync(Send send, string fileId) => await DeleteBlobAsync(BlobName(send, fileId));
public async Task DeleteBlobAsync(string blobName)
{
await InitAsync();
var blobClient = _sendFilesContainerClient.GetBlobClient(blobName);
await blobClient.DeleteIfExistsAsync();
}
public async Task DeleteFilesForOrganizationAsync(Guid organizationId)
{
await InitAsync();
}
public async Task DeleteFilesForUserAsync(Guid userId)
{
await InitAsync();
}
public async Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId)
{
await InitAsync();
var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId));
var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(_downloadLinkLiveTime));
return sasUri.ToString();
}
public async Task<string> GetSendFileUploadUrlAsync(Send send, string fileId)
{
await InitAsync();
var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId));
var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(_downloadLinkLiveTime));
return sasUri.ToString();
}
public async Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway)
{
await InitAsync();
var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId));
try
{
var blobProperties = await blobClient.GetPropertiesAsync();
var metadata = blobProperties.Value.Metadata;
var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId));
var metadata = new Dictionary<string, string>();
if (send.UserId.HasValue)
{
metadata["userId"] = send.UserId.Value.ToString();
metadata.Add("userId", send.UserId.Value.ToString());
}
else
{
metadata["organizationId"] = send.OrganizationId.Value.ToString();
metadata.Add("organizationId", send.OrganizationId.Value.ToString());
}
await blobClient.SetMetadataAsync(metadata);
var headers = new BlobHttpHeaders
{
ContentDisposition = $"attachment; filename=\"{fileId}\""
};
await blobClient.SetHttpHeadersAsync(headers);
var length = blobProperties.Value.ContentLength;
if (length < expectedFileSize - leeway || length > expectedFileSize + leeway)
await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers });
}
public async Task DeleteFileAsync(Send send, string fileId) => await DeleteBlobAsync(BlobName(send, fileId));
public async Task DeleteBlobAsync(string blobName)
{
await InitAsync();
var blobClient = _sendFilesContainerClient.GetBlobClient(blobName);
await blobClient.DeleteIfExistsAsync();
}
public async Task DeleteFilesForOrganizationAsync(Guid organizationId)
{
await InitAsync();
}
public async Task DeleteFilesForUserAsync(Guid userId)
{
await InitAsync();
}
public async Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId)
{
await InitAsync();
var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId));
var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(_downloadLinkLiveTime));
return sasUri.ToString();
}
public async Task<string> GetSendFileUploadUrlAsync(Send send, string fileId)
{
await InitAsync();
var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId));
var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(_downloadLinkLiveTime));
return sasUri.ToString();
}
public async Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway)
{
await InitAsync();
var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId));
try
{
return (false, length);
var blobProperties = await blobClient.GetPropertiesAsync();
var metadata = blobProperties.Value.Metadata;
if (send.UserId.HasValue)
{
metadata["userId"] = send.UserId.Value.ToString();
}
else
{
metadata["organizationId"] = send.OrganizationId.Value.ToString();
}
await blobClient.SetMetadataAsync(metadata);
var headers = new BlobHttpHeaders
{
ContentDisposition = $"attachment; filename=\"{fileId}\""
};
await blobClient.SetHttpHeadersAsync(headers);
var length = blobProperties.Value.ContentLength;
if (length < expectedFileSize - leeway || length > expectedFileSize + leeway)
{
return (false, length);
}
return (true, length);
}
catch (Exception ex)
{
_logger.LogError(ex, "Unhandled error in ValidateFileAsync");
return (false, null);
}
return (true, length);
}
catch (Exception ex)
{
_logger.LogError(ex, "Unhandled error in ValidateFileAsync");
return (false, null);
}
}
private async Task InitAsync()
{
if (_sendFilesContainerClient == null)
private async Task InitAsync()
{
_sendFilesContainerClient = _blobServiceClient.GetBlobContainerClient(FilesContainerName);
await _sendFilesContainerClient.CreateIfNotExistsAsync(PublicAccessType.None, null, null);
if (_sendFilesContainerClient == null)
{
_sendFilesContainerClient = _blobServiceClient.GetBlobContainerClient(FilesContainerName);
await _sendFilesContainerClient.CreateIfNotExistsAsync(PublicAccessType.None, null, null);
}
}
}
}

View File

@ -5,202 +5,203 @@ using System.Text.Json;
using Bit.Core.Utilities;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public abstract class BaseIdentityClientService : IDisposable
namespace Bit.Core.Services
{
private readonly IHttpClientFactory _httpFactory;
private readonly string _identityScope;
private readonly string _identityClientId;
private readonly string _identityClientSecret;
protected readonly ILogger<BaseIdentityClientService> _logger;
private JsonDocument _decodedToken;
private DateTime? _nextAuthAttempt = null;
public BaseIdentityClientService(
IHttpClientFactory httpFactory,
string baseClientServerUri,
string baseIdentityServerUri,
string identityScope,
string identityClientId,
string identityClientSecret,
ILogger<BaseIdentityClientService> logger)
public abstract class BaseIdentityClientService : IDisposable
{
_httpFactory = httpFactory;
_identityScope = identityScope;
_identityClientId = identityClientId;
_identityClientSecret = identityClientSecret;
_logger = logger;
private readonly IHttpClientFactory _httpFactory;
private readonly string _identityScope;
private readonly string _identityClientId;
private readonly string _identityClientSecret;
protected readonly ILogger<BaseIdentityClientService> _logger;
Client = _httpFactory.CreateClient("client");
Client.BaseAddress = new Uri(baseClientServerUri);
Client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
private JsonDocument _decodedToken;
private DateTime? _nextAuthAttempt = null;
IdentityClient = _httpFactory.CreateClient("identity");
IdentityClient.BaseAddress = new Uri(baseIdentityServerUri);
IdentityClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
}
protected HttpClient Client { get; private set; }
protected HttpClient IdentityClient { get; private set; }
protected string AccessToken { get; private set; }
protected Task SendAsync(HttpMethod method, string path) =>
SendAsync<object, object>(method, path, null);
protected Task SendAsync<TRequest>(HttpMethod method, string path, TRequest body) =>
SendAsync<TRequest, object>(method, path, body);
protected async Task<TResult> SendAsync<TRequest, TResult>(HttpMethod method, string path, TRequest requestModel)
{
var tokenStateResponse = await HandleTokenStateAsync();
if (!tokenStateResponse)
public BaseIdentityClientService(
IHttpClientFactory httpFactory,
string baseClientServerUri,
string baseIdentityServerUri,
string identityScope,
string identityClientId,
string identityClientSecret,
ILogger<BaseIdentityClientService> logger)
{
return default;
_httpFactory = httpFactory;
_identityScope = identityScope;
_identityClientId = identityClientId;
_identityClientSecret = identityClientSecret;
_logger = logger;
Client = _httpFactory.CreateClient("client");
Client.BaseAddress = new Uri(baseClientServerUri);
Client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
IdentityClient = _httpFactory.CreateClient("identity");
IdentityClient.BaseAddress = new Uri(baseIdentityServerUri);
IdentityClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
}
var message = new TokenHttpRequestMessage(requestModel, AccessToken)
{
Method = method,
RequestUri = new Uri(string.Concat(Client.BaseAddress, path))
};
try
{
var response = await Client.SendAsync(message);
return await response.Content.ReadFromJsonAsync<TResult>();
}
catch (Exception e)
{
_logger.LogError(12334, e, "Failed to send to {0}.", message.RequestUri.ToString());
return default;
}
}
protected HttpClient Client { get; private set; }
protected HttpClient IdentityClient { get; private set; }
protected string AccessToken { get; private set; }
protected async Task<bool> HandleTokenStateAsync()
{
if (_nextAuthAttempt.HasValue && DateTime.UtcNow > _nextAuthAttempt.Value)
{
return false;
}
_nextAuthAttempt = null;
protected Task SendAsync(HttpMethod method, string path) =>
SendAsync<object, object>(method, path, null);
if (!string.IsNullOrWhiteSpace(AccessToken) && !TokenNeedsRefresh())
protected Task SendAsync<TRequest>(HttpMethod method, string path, TRequest body) =>
SendAsync<TRequest, object>(method, path, body);
protected async Task<TResult> SendAsync<TRequest, TResult>(HttpMethod method, string path, TRequest requestModel)
{
var tokenStateResponse = await HandleTokenStateAsync();
if (!tokenStateResponse)
{
return default;
}
var message = new TokenHttpRequestMessage(requestModel, AccessToken)
{
Method = method,
RequestUri = new Uri(string.Concat(Client.BaseAddress, path))
};
try
{
var response = await Client.SendAsync(message);
return await response.Content.ReadFromJsonAsync<TResult>();
}
catch (Exception e)
{
_logger.LogError(12334, e, "Failed to send to {0}.", message.RequestUri.ToString());
return default;
}
}
protected async Task<bool> HandleTokenStateAsync()
{
if (_nextAuthAttempt.HasValue && DateTime.UtcNow > _nextAuthAttempt.Value)
{
return false;
}
_nextAuthAttempt = null;
if (!string.IsNullOrWhiteSpace(AccessToken) && !TokenNeedsRefresh())
{
return true;
}
var requestMessage = new HttpRequestMessage
{
Method = HttpMethod.Post,
RequestUri = new Uri(string.Concat(IdentityClient.BaseAddress, "connect/token")),
Content = new FormUrlEncodedContent(new Dictionary<string, string>
{
{ "grant_type", "client_credentials" },
{ "scope", _identityScope },
{ "client_id", _identityClientId },
{ "client_secret", _identityClientSecret }
})
};
HttpResponseMessage response = null;
try
{
response = await IdentityClient.SendAsync(requestMessage);
}
catch (Exception e)
{
_logger.LogError(12339, e, "Unable to authenticate with identity server.");
}
if (response == null)
{
return false;
}
if (!response.IsSuccessStatusCode)
{
_logger.LogInformation("Unsuccessful token response with status code {StatusCode}", response.StatusCode);
if (response.StatusCode == HttpStatusCode.BadRequest)
{
_nextAuthAttempt = DateTime.UtcNow.AddDays(1);
}
if (_logger.IsEnabled(LogLevel.Debug))
{
var responseBody = await response.Content.ReadAsStringAsync();
_logger.LogDebug("Error response body:\n{ResponseBody}", responseBody);
}
return false;
}
using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync());
AccessToken = jsonDocument.RootElement.GetProperty("access_token").GetString();
return true;
}
var requestMessage = new HttpRequestMessage
protected class TokenHttpRequestMessage : HttpRequestMessage
{
Method = HttpMethod.Post,
RequestUri = new Uri(string.Concat(IdentityClient.BaseAddress, "connect/token")),
Content = new FormUrlEncodedContent(new Dictionary<string, string>
public TokenHttpRequestMessage(string token)
{
{ "grant_type", "client_credentials" },
{ "scope", _identityScope },
{ "client_id", _identityClientId },
{ "client_secret", _identityClientSecret }
})
};
HttpResponseMessage response = null;
try
{
response = await IdentityClient.SendAsync(requestMessage);
}
catch (Exception e)
{
_logger.LogError(12339, e, "Unable to authenticate with identity server.");
}
if (response == null)
{
return false;
}
if (!response.IsSuccessStatusCode)
{
_logger.LogInformation("Unsuccessful token response with status code {StatusCode}", response.StatusCode);
if (response.StatusCode == HttpStatusCode.BadRequest)
{
_nextAuthAttempt = DateTime.UtcNow.AddDays(1);
Headers.Add("Authorization", $"Bearer {token}");
}
if (_logger.IsEnabled(LogLevel.Debug))
public TokenHttpRequestMessage(object requestObject, string token)
: this(token)
{
var responseBody = await response.Content.ReadAsStringAsync();
_logger.LogDebug("Error response body:\n{ResponseBody}", responseBody);
}
return false;
}
using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync());
AccessToken = jsonDocument.RootElement.GetProperty("access_token").GetString();
return true;
}
protected class TokenHttpRequestMessage : HttpRequestMessage
{
public TokenHttpRequestMessage(string token)
{
Headers.Add("Authorization", $"Bearer {token}");
}
public TokenHttpRequestMessage(object requestObject, string token)
: this(token)
{
if (requestObject != null)
{
Content = JsonContent.Create(requestObject);
if (requestObject != null)
{
Content = JsonContent.Create(requestObject);
}
}
}
}
protected bool TokenNeedsRefresh(int minutes = 5)
{
var decoded = DecodeToken();
if (!decoded.RootElement.TryGetProperty("exp", out var expProp))
protected bool TokenNeedsRefresh(int minutes = 5)
{
throw new InvalidOperationException("No exp in token.");
var decoded = DecodeToken();
if (!decoded.RootElement.TryGetProperty("exp", out var expProp))
{
throw new InvalidOperationException("No exp in token.");
}
var expiration = CoreHelpers.FromEpocSeconds(expProp.GetInt64());
return DateTime.UtcNow.AddMinutes(-1 * minutes) > expiration;
}
var expiration = CoreHelpers.FromEpocSeconds(expProp.GetInt64());
return DateTime.UtcNow.AddMinutes(-1 * minutes) > expiration;
}
protected JsonDocument DecodeToken()
{
if (_decodedToken != null)
protected JsonDocument DecodeToken()
{
if (_decodedToken != null)
{
return _decodedToken;
}
if (AccessToken == null)
{
throw new InvalidOperationException($"{nameof(AccessToken)} not found.");
}
var parts = AccessToken.Split('.');
if (parts.Length != 3)
{
throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts");
}
var decodedBytes = CoreHelpers.Base64UrlDecode(parts[1]);
if (decodedBytes == null || decodedBytes.Length < 1)
{
throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts");
}
_decodedToken = JsonDocument.Parse(decodedBytes);
return _decodedToken;
}
if (AccessToken == null)
public void Dispose()
{
throw new InvalidOperationException($"{nameof(AccessToken)} not found.");
_decodedToken?.Dispose();
}
var parts = AccessToken.Split('.');
if (parts.Length != 3)
{
throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts");
}
var decodedBytes = CoreHelpers.Base64UrlDecode(parts[1]);
if (decodedBytes == null || decodedBytes.Length < 1)
{
throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts");
}
_decodedToken = JsonDocument.Parse(decodedBytes);
return _decodedToken;
}
public void Dispose()
{
_decodedToken?.Dispose();
}
}

View File

@ -1,19 +1,20 @@
using Bit.Core.Models.Mail;
namespace Bit.Core.Services;
public class BlockingMailEnqueuingService : IMailEnqueuingService
namespace Bit.Core.Services
{
public async Task EnqueueAsync(IMailQueueMessage message, Func<IMailQueueMessage, Task> fallback)
public class BlockingMailEnqueuingService : IMailEnqueuingService
{
await fallback(message);
}
public async Task EnqueueManyAsync(IEnumerable<IMailQueueMessage> messages, Func<IMailQueueMessage, Task> fallback)
{
foreach (var message in messages)
public async Task EnqueueAsync(IMailQueueMessage message, Func<IMailQueueMessage, Task> fallback)
{
await fallback(message);
}
public async Task EnqueueManyAsync(IEnumerable<IMailQueueMessage> messages, Func<IMailQueueMessage, Task> fallback)
{
foreach (var message in messages)
{
await fallback(message);
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -6,135 +6,136 @@ using Bit.Core.Models.Business;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
namespace Bit.Core.Services;
public class CollectionService : ICollectionService
namespace Bit.Core.Services
{
private readonly IEventService _eventService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly ICollectionRepository _collectionRepository;
private readonly IUserRepository _userRepository;
private readonly IMailService _mailService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
public CollectionService(
IEventService eventService,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository,
IUserRepository userRepository,
IMailService mailService,
IReferenceEventService referenceEventService,
ICurrentContext currentContext)
public class CollectionService : ICollectionService
{
_eventService = eventService;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_collectionRepository = collectionRepository;
_userRepository = userRepository;
_mailService = mailService;
_referenceEventService = referenceEventService;
_currentContext = currentContext;
}
private readonly IEventService _eventService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly ICollectionRepository _collectionRepository;
private readonly IUserRepository _userRepository;
private readonly IMailService _mailService;
private readonly IReferenceEventService _referenceEventService;
private readonly ICurrentContext _currentContext;
public async Task SaveAsync(Collection collection, IEnumerable<SelectionReadOnly> groups = null,
Guid? assignUserId = null)
{
var org = await _organizationRepository.GetByIdAsync(collection.OrganizationId);
if (org == null)
public CollectionService(
IEventService eventService,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository,
IUserRepository userRepository,
IMailService mailService,
IReferenceEventService referenceEventService,
ICurrentContext currentContext)
{
throw new BadRequestException("Organization not found");
_eventService = eventService;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_collectionRepository = collectionRepository;
_userRepository = userRepository;
_mailService = mailService;
_referenceEventService = referenceEventService;
_currentContext = currentContext;
}
if (collection.Id == default(Guid))
public async Task SaveAsync(Collection collection, IEnumerable<SelectionReadOnly> groups = null,
Guid? assignUserId = null)
{
if (org.MaxCollections.HasValue)
var org = await _organizationRepository.GetByIdAsync(collection.OrganizationId);
if (org == null)
{
var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id);
if (org.MaxCollections.Value <= collectionCount)
{
throw new BadRequestException("You have reached the maximum number of collections " +
$"({org.MaxCollections.Value}) for this organization.");
}
throw new BadRequestException("Organization not found");
}
if (groups == null || !org.UseGroups)
if (collection.Id == default(Guid))
{
await _collectionRepository.CreateAsync(collection);
if (org.MaxCollections.HasValue)
{
var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id);
if (org.MaxCollections.Value <= collectionCount)
{
throw new BadRequestException("You have reached the maximum number of collections " +
$"({org.MaxCollections.Value}) for this organization.");
}
}
if (groups == null || !org.UseGroups)
{
await _collectionRepository.CreateAsync(collection);
}
else
{
await _collectionRepository.CreateAsync(collection, groups);
}
// Assign a user to the newly created collection.
if (assignUserId.HasValue)
{
var orgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, assignUserId.Value);
if (orgUser != null && orgUser.Status == Enums.OrganizationUserStatusType.Confirmed)
{
await _collectionRepository.UpdateUsersAsync(collection.Id,
new List<SelectionReadOnly> {
new SelectionReadOnly { Id = orgUser.Id, ReadOnly = false } });
}
}
await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Created);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.CollectionCreated, org));
}
else
{
await _collectionRepository.CreateAsync(collection, groups);
}
// Assign a user to the newly created collection.
if (assignUserId.HasValue)
{
var orgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, assignUserId.Value);
if (orgUser != null && orgUser.Status == Enums.OrganizationUserStatusType.Confirmed)
if (!org.UseGroups)
{
await _collectionRepository.UpdateUsersAsync(collection.Id,
new List<SelectionReadOnly> {
new SelectionReadOnly { Id = orgUser.Id, ReadOnly = false } });
await _collectionRepository.ReplaceAsync(collection);
}
else
{
await _collectionRepository.ReplaceAsync(collection, groups ?? new List<SelectionReadOnly>());
}
await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Updated);
}
}
public async Task DeleteAsync(Collection collection)
{
await _collectionRepository.DeleteAsync(collection);
await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Deleted);
}
public async Task DeleteUserAsync(Collection collection, Guid organizationUserId)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId);
if (orgUser == null || orgUser.OrganizationId != collection.OrganizationId)
{
throw new NotFoundException();
}
await _collectionRepository.DeleteUserAsync(collection.Id, organizationUserId);
await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_Updated);
}
public async Task<IEnumerable<Collection>> GetOrganizationCollections(Guid organizationId)
{
if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.ManageUsers(organizationId))
{
throw new NotFoundException();
}
await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Created);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.CollectionCreated, org));
}
else
{
if (!org.UseGroups)
IEnumerable<Collection> orgCollections;
if (await _currentContext.OrganizationAdmin(organizationId) || await _currentContext.ViewAllCollections(organizationId))
{
await _collectionRepository.ReplaceAsync(collection);
// Admins, Owners, Providers and Custom (with collection management permissions) can access all items even if not assigned to them
orgCollections = await _collectionRepository.GetManyByOrganizationIdAsync(organizationId);
}
else
{
await _collectionRepository.ReplaceAsync(collection, groups ?? new List<SelectionReadOnly>());
var collections = await _collectionRepository.GetManyByUserIdAsync(_currentContext.UserId.Value);
orgCollections = collections.Where(c => c.OrganizationId == organizationId);
}
await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Updated);
return orgCollections;
}
}
public async Task DeleteAsync(Collection collection)
{
await _collectionRepository.DeleteAsync(collection);
await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Deleted);
}
public async Task DeleteUserAsync(Collection collection, Guid organizationUserId)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId);
if (orgUser == null || orgUser.OrganizationId != collection.OrganizationId)
{
throw new NotFoundException();
}
await _collectionRepository.DeleteUserAsync(collection.Id, organizationUserId);
await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_Updated);
}
public async Task<IEnumerable<Collection>> GetOrganizationCollections(Guid organizationId)
{
if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.ManageUsers(organizationId))
{
throw new NotFoundException();
}
IEnumerable<Collection> orgCollections;
if (await _currentContext.OrganizationAdmin(organizationId) || await _currentContext.ViewAllCollections(organizationId))
{
// Admins, Owners, Providers and Custom (with collection management permissions) can access all items even if not assigned to them
orgCollections = await _collectionRepository.GetManyByOrganizationIdAsync(organizationId);
}
else
{
var collections = await _collectionRepository.GetManyByUserIdAsync(_currentContext.UserId.Value);
orgCollections = collections.Where(c => c.OrganizationId == organizationId);
}
return orgCollections;
}
}

View File

@ -1,46 +1,47 @@
using Bit.Core.Entities;
using Bit.Core.Repositories;
namespace Bit.Core.Services;
public class DeviceService : IDeviceService
namespace Bit.Core.Services
{
private readonly IDeviceRepository _deviceRepository;
private readonly IPushRegistrationService _pushRegistrationService;
public DeviceService(
IDeviceRepository deviceRepository,
IPushRegistrationService pushRegistrationService)
public class DeviceService : IDeviceService
{
_deviceRepository = deviceRepository;
_pushRegistrationService = pushRegistrationService;
}
private readonly IDeviceRepository _deviceRepository;
private readonly IPushRegistrationService _pushRegistrationService;
public async Task SaveAsync(Device device)
{
if (device.Id == default(Guid))
public DeviceService(
IDeviceRepository deviceRepository,
IPushRegistrationService pushRegistrationService)
{
await _deviceRepository.CreateAsync(device);
}
else
{
device.RevisionDate = DateTime.UtcNow;
await _deviceRepository.ReplaceAsync(device);
_deviceRepository = deviceRepository;
_pushRegistrationService = pushRegistrationService;
}
await _pushRegistrationService.CreateOrUpdateRegistrationAsync(device.PushToken, device.Id.ToString(),
device.UserId.ToString(), device.Identifier, device.Type);
}
public async Task SaveAsync(Device device)
{
if (device.Id == default(Guid))
{
await _deviceRepository.CreateAsync(device);
}
else
{
device.RevisionDate = DateTime.UtcNow;
await _deviceRepository.ReplaceAsync(device);
}
public async Task ClearTokenAsync(Device device)
{
await _deviceRepository.ClearPushTokenAsync(device.Id);
await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString());
}
await _pushRegistrationService.CreateOrUpdateRegistrationAsync(device.PushToken, device.Id.ToString(),
device.UserId.ToString(), device.Identifier, device.Type);
}
public async Task DeleteAsync(Device device)
{
await _deviceRepository.DeleteAsync(device);
await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString());
public async Task ClearTokenAsync(Device device)
{
await _deviceRepository.ClearPushTokenAsync(device.Id);
await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString());
}
public async Task DeleteAsync(Device device)
{
await _deviceRepository.DeleteAsync(device);
await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString());
}
}
}

View File

@ -9,415 +9,416 @@ using Bit.Core.Settings;
using Bit.Core.Tokens;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Services;
public class EmergencyAccessService : IEmergencyAccessService
namespace Bit.Core.Services
{
private readonly IEmergencyAccessRepository _emergencyAccessRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IUserRepository _userRepository;
private readonly ICipherRepository _cipherRepository;
private readonly IPolicyRepository _policyRepository;
private readonly ICipherService _cipherService;
private readonly IMailService _mailService;
private readonly IUserService _userService;
private readonly GlobalSettings _globalSettings;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IOrganizationService _organizationService;
private readonly IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> _dataProtectorTokenizer;
public EmergencyAccessService(
IEmergencyAccessRepository emergencyAccessRepository,
IOrganizationUserRepository organizationUserRepository,
IUserRepository userRepository,
ICipherRepository cipherRepository,
IPolicyRepository policyRepository,
ICipherService cipherService,
IMailService mailService,
IUserService userService,
IPasswordHasher<User> passwordHasher,
GlobalSettings globalSettings,
IOrganizationService organizationService,
IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> dataProtectorTokenizer)
public class EmergencyAccessService : IEmergencyAccessService
{
_emergencyAccessRepository = emergencyAccessRepository;
_organizationUserRepository = organizationUserRepository;
_userRepository = userRepository;
_cipherRepository = cipherRepository;
_policyRepository = policyRepository;
_cipherService = cipherService;
_mailService = mailService;
_userService = userService;
_passwordHasher = passwordHasher;
_globalSettings = globalSettings;
_organizationService = organizationService;
_dataProtectorTokenizer = dataProtectorTokenizer;
}
private readonly IEmergencyAccessRepository _emergencyAccessRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IUserRepository _userRepository;
private readonly ICipherRepository _cipherRepository;
private readonly IPolicyRepository _policyRepository;
private readonly ICipherService _cipherService;
private readonly IMailService _mailService;
private readonly IUserService _userService;
private readonly GlobalSettings _globalSettings;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IOrganizationService _organizationService;
private readonly IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> _dataProtectorTokenizer;
public async Task<EmergencyAccess> InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime)
{
if (!await _userService.CanAccessPremium(invitingUser))
public EmergencyAccessService(
IEmergencyAccessRepository emergencyAccessRepository,
IOrganizationUserRepository organizationUserRepository,
IUserRepository userRepository,
ICipherRepository cipherRepository,
IPolicyRepository policyRepository,
ICipherService cipherService,
IMailService mailService,
IUserService userService,
IPasswordHasher<User> passwordHasher,
GlobalSettings globalSettings,
IOrganizationService organizationService,
IDataProtectorTokenFactory<EmergencyAccessInviteTokenable> dataProtectorTokenizer)
{
throw new BadRequestException("Not a premium user.");
_emergencyAccessRepository = emergencyAccessRepository;
_organizationUserRepository = organizationUserRepository;
_userRepository = userRepository;
_cipherRepository = cipherRepository;
_policyRepository = policyRepository;
_cipherService = cipherService;
_mailService = mailService;
_userService = userService;
_passwordHasher = passwordHasher;
_globalSettings = globalSettings;
_organizationService = organizationService;
_dataProtectorTokenizer = dataProtectorTokenizer;
}
if (type == EmergencyAccessType.Takeover && invitingUser.UsesKeyConnector)
public async Task<EmergencyAccess> InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime)
{
throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector.");
}
if (!await _userService.CanAccessPremium(invitingUser))
{
throw new BadRequestException("Not a premium user.");
}
var emergencyAccess = new EmergencyAccess
{
GrantorId = invitingUser.Id,
Email = email.ToLowerInvariant(),
Status = EmergencyAccessStatusType.Invited,
Type = type,
WaitTimeDays = waitTime,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
};
await _emergencyAccessRepository.CreateAsync(emergencyAccess);
await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser));
return emergencyAccess;
}
public async Task<EmergencyAccessDetails> GetAsync(Guid emergencyAccessId, Guid userId)
{
var emergencyAccess = await _emergencyAccessRepository.GetDetailsByIdGrantorIdAsync(emergencyAccessId, userId);
if (emergencyAccess == null)
{
throw new BadRequestException("Emergency Access not valid.");
}
return emergencyAccess;
}
public async Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null || emergencyAccess.GrantorId != invitingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.Invited)
{
throw new BadRequestException("Emergency Access not valid.");
}
await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser));
}
public async Task<EmergencyAccess> AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null)
{
throw new BadRequestException("Emergency Access not valid.");
}
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email))
{
throw new BadRequestException("Invalid token.");
}
if (emergencyAccess.Status == EmergencyAccessStatusType.Accepted)
{
throw new BadRequestException("Invitation already accepted. You will receive an email when the grantor confirms you as an emergency access contact.");
}
else if (emergencyAccess.Status != EmergencyAccessStatusType.Invited)
{
throw new BadRequestException("Invitation already accepted.");
}
if (string.IsNullOrWhiteSpace(emergencyAccess.Email) ||
!emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{
throw new BadRequestException("User email does not match invite.");
}
var granteeEmail = emergencyAccess.Email;
emergencyAccess.Status = EmergencyAccessStatusType.Accepted;
emergencyAccess.GranteeId = user.Id;
emergencyAccess.Email = null;
var grantor = await userService.GetUserByIdAsync(emergencyAccess.GrantorId);
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
await _mailService.SendEmergencyAccessAcceptedEmailAsync(granteeEmail, grantor.Email);
return emergencyAccess;
}
public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId))
{
throw new BadRequestException("Emergency Access not valid.");
}
await _emergencyAccessRepository.DeleteAsync(emergencyAccess);
}
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId);
if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted ||
emergencyAccess.GrantorId != confirmingUserId)
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(confirmingUserId);
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{
throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector.");
}
var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value);
emergencyAccess.Status = EmergencyAccessStatusType.Confirmed;
emergencyAccess.KeyEncrypted = key;
emergencyAccess.Email = null;
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
await _mailService.SendEmergencyAccessConfirmedEmailAsync(NameOrEmail(grantor), grantee.Email);
return emergencyAccess;
}
public async Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser)
{
if (!await _userService.CanAccessPremium(savingUser))
{
throw new BadRequestException("Not a premium user.");
}
if (emergencyAccess.GrantorId != savingUser.Id)
{
throw new BadRequestException("Emergency Access not valid.");
}
if (emergencyAccess.Type == EmergencyAccessType.Takeover)
{
var grantor = await _userService.GetUserByIdAsync(emergencyAccess.GrantorId);
if (grantor.UsesKeyConnector)
if (type == EmergencyAccessType.Takeover && invitingUser.UsesKeyConnector)
{
throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector.");
}
}
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
}
public async Task InitiateAsync(Guid id, User initiatingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.Confirmed)
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{
throw new BadRequestException("You cannot takeover an account that is using Key Connector.");
}
var now = DateTime.UtcNow;
emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated;
emergencyAccess.RevisionDate = now;
emergencyAccess.RecoveryInitiatedDate = now;
emergencyAccess.LastNotificationDate = now;
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
await _mailService.SendEmergencyAccessRecoveryInitiated(emergencyAccess, NameOrEmail(initiatingUser), grantor.Email);
}
public async Task ApproveAsync(Guid id, User approvingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GrantorId != approvingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated)
{
throw new BadRequestException("Emergency Access not valid.");
}
emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved;
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value);
await _mailService.SendEmergencyAccessRecoveryApproved(emergencyAccess, NameOrEmail(approvingUser), grantee.Email);
}
public async Task RejectAsync(Guid id, User rejectingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GrantorId != rejectingUser.Id ||
(emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated &&
emergencyAccess.Status != EmergencyAccessStatusType.RecoveryApproved))
{
throw new BadRequestException("Emergency Access not valid.");
}
emergencyAccess.Status = EmergencyAccessStatusType.Confirmed;
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value);
await _mailService.SendEmergencyAccessRecoveryRejected(emergencyAccess, NameOrEmail(rejectingUser), grantee.Email);
}
public async Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id);
var isOrganizationOwner = grantorOrganizations.Any<OrganizationUser>(organization => organization.Type == OrganizationUserType.Owner);
var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null;
return policies;
}
public async Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User requestingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{
throw new BadRequestException("You cannot takeover an account that is using Key Connector.");
}
return (emergencyAccess, grantor);
}
public async Task PasswordAsync(Guid id, User requestingUser, string newMasterPasswordHash, string key)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
grantor.MasterPassword = _passwordHasher.HashPassword(grantor, newMasterPasswordHash);
grantor.Key = key;
// Disable TwoFactor providers since they will otherwise block logins
grantor.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>());
grantor.UnknownDeviceVerificationEnabled = false;
await _userRepository.ReplaceAsync(grantor);
// Remove grantor from all organizations unless Owner
var orgUser = await _organizationUserRepository.GetManyByUserAsync(grantor.Id);
foreach (var o in orgUser)
{
if (o.Type != OrganizationUserType.Owner)
var emergencyAccess = new EmergencyAccess
{
await _organizationService.DeleteUserAsync(o.OrganizationId, grantor.Id);
GrantorId = invitingUser.Id,
Email = email.ToLowerInvariant(),
Status = EmergencyAccessStatusType.Invited,
Type = type,
WaitTimeDays = waitTime,
CreationDate = DateTime.UtcNow,
RevisionDate = DateTime.UtcNow,
};
await _emergencyAccessRepository.CreateAsync(emergencyAccess);
await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser));
return emergencyAccess;
}
public async Task<EmergencyAccessDetails> GetAsync(Guid emergencyAccessId, Guid userId)
{
var emergencyAccess = await _emergencyAccessRepository.GetDetailsByIdGrantorIdAsync(emergencyAccessId, userId);
if (emergencyAccess == null)
{
throw new BadRequestException("Emergency Access not valid.");
}
return emergencyAccess;
}
public async Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null || emergencyAccess.GrantorId != invitingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.Invited)
{
throw new BadRequestException("Emergency Access not valid.");
}
await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser));
}
public async Task<EmergencyAccess> AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null)
{
throw new BadRequestException("Emergency Access not valid.");
}
if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email))
{
throw new BadRequestException("Invalid token.");
}
if (emergencyAccess.Status == EmergencyAccessStatusType.Accepted)
{
throw new BadRequestException("Invitation already accepted. You will receive an email when the grantor confirms you as an emergency access contact.");
}
else if (emergencyAccess.Status != EmergencyAccessStatusType.Invited)
{
throw new BadRequestException("Invitation already accepted.");
}
if (string.IsNullOrWhiteSpace(emergencyAccess.Email) ||
!emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase))
{
throw new BadRequestException("User email does not match invite.");
}
var granteeEmail = emergencyAccess.Email;
emergencyAccess.Status = EmergencyAccessStatusType.Accepted;
emergencyAccess.GranteeId = user.Id;
emergencyAccess.Email = null;
var grantor = await userService.GetUserByIdAsync(emergencyAccess.GrantorId);
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
await _mailService.SendEmergencyAccessAcceptedEmailAsync(granteeEmail, grantor.Email);
return emergencyAccess;
}
public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId);
if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId))
{
throw new BadRequestException("Emergency Access not valid.");
}
await _emergencyAccessRepository.DeleteAsync(emergencyAccess);
}
public async Task<EmergencyAccess> ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId);
if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted ||
emergencyAccess.GrantorId != confirmingUserId)
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(confirmingUserId);
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{
throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector.");
}
var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value);
emergencyAccess.Status = EmergencyAccessStatusType.Confirmed;
emergencyAccess.KeyEncrypted = key;
emergencyAccess.Email = null;
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
await _mailService.SendEmergencyAccessConfirmedEmailAsync(NameOrEmail(grantor), grantee.Email);
return emergencyAccess;
}
public async Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser)
{
if (!await _userService.CanAccessPremium(savingUser))
{
throw new BadRequestException("Not a premium user.");
}
if (emergencyAccess.GrantorId != savingUser.Id)
{
throw new BadRequestException("Emergency Access not valid.");
}
if (emergencyAccess.Type == EmergencyAccessType.Takeover)
{
var grantor = await _userService.GetUserByIdAsync(emergencyAccess.GrantorId);
if (grantor.UsesKeyConnector)
{
throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector.");
}
}
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
}
public async Task InitiateAsync(Guid id, User initiatingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.Confirmed)
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{
throw new BadRequestException("You cannot takeover an account that is using Key Connector.");
}
var now = DateTime.UtcNow;
emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated;
emergencyAccess.RevisionDate = now;
emergencyAccess.RecoveryInitiatedDate = now;
emergencyAccess.LastNotificationDate = now;
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
await _mailService.SendEmergencyAccessRecoveryInitiated(emergencyAccess, NameOrEmail(initiatingUser), grantor.Email);
}
public async Task ApproveAsync(Guid id, User approvingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GrantorId != approvingUser.Id ||
emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated)
{
throw new BadRequestException("Emergency Access not valid.");
}
emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved;
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value);
await _mailService.SendEmergencyAccessRecoveryApproved(emergencyAccess, NameOrEmail(approvingUser), grantee.Email);
}
public async Task RejectAsync(Guid id, User rejectingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (emergencyAccess == null || emergencyAccess.GrantorId != rejectingUser.Id ||
(emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated &&
emergencyAccess.Status != EmergencyAccessStatusType.RecoveryApproved))
{
throw new BadRequestException("Emergency Access not valid.");
}
emergencyAccess.Status = EmergencyAccessStatusType.Confirmed;
await _emergencyAccessRepository.ReplaceAsync(emergencyAccess);
var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value);
await _mailService.SendEmergencyAccessRecoveryRejected(emergencyAccess, NameOrEmail(rejectingUser), grantee.Email);
}
public async Task<ICollection<Policy>> GetPoliciesAsync(Guid id, User requestingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id);
var isOrganizationOwner = grantorOrganizations.Any<OrganizationUser>(organization => organization.Type == OrganizationUserType.Owner);
var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null;
return policies;
}
public async Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User requestingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector)
{
throw new BadRequestException("You cannot takeover an account that is using Key Connector.");
}
return (emergencyAccess, grantor);
}
public async Task PasswordAsync(Guid id, User requestingUser, string newMasterPasswordHash, string key)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover))
{
throw new BadRequestException("Emergency Access not valid.");
}
var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId);
grantor.MasterPassword = _passwordHasher.HashPassword(grantor, newMasterPasswordHash);
grantor.Key = key;
// Disable TwoFactor providers since they will otherwise block logins
grantor.SetTwoFactorProviders(new Dictionary<TwoFactorProviderType, TwoFactorProvider>());
grantor.UnknownDeviceVerificationEnabled = false;
await _userRepository.ReplaceAsync(grantor);
// Remove grantor from all organizations unless Owner
var orgUser = await _organizationUserRepository.GetManyByUserAsync(grantor.Id);
foreach (var o in orgUser)
{
if (o.Type != OrganizationUserType.Owner)
{
await _organizationService.DeleteUserAsync(o.OrganizationId, grantor.Id);
}
}
}
}
public async Task SendNotificationsAsync()
{
var toNotify = await _emergencyAccessRepository.GetManyToNotifyAsync();
foreach (var notify in toNotify)
public async Task SendNotificationsAsync()
{
var ea = notify.ToEmergencyAccess();
ea.LastNotificationDate = DateTime.UtcNow;
await _emergencyAccessRepository.ReplaceAsync(ea);
var toNotify = await _emergencyAccessRepository.GetManyToNotifyAsync();
var granteeNameOrEmail = string.IsNullOrWhiteSpace(notify.GranteeName) ? notify.GranteeEmail : notify.GranteeName;
foreach (var notify in toNotify)
{
var ea = notify.ToEmergencyAccess();
ea.LastNotificationDate = DateTime.UtcNow;
await _emergencyAccessRepository.ReplaceAsync(ea);
await _mailService.SendEmergencyAccessRecoveryReminder(ea, granteeNameOrEmail, notify.GrantorEmail);
}
}
var granteeNameOrEmail = string.IsNullOrWhiteSpace(notify.GranteeName) ? notify.GranteeEmail : notify.GranteeName;
public async Task HandleTimedOutRequestsAsync()
{
var expired = await _emergencyAccessRepository.GetExpiredRecoveriesAsync();
foreach (var details in expired)
{
var ea = details.ToEmergencyAccess();
ea.Status = EmergencyAccessStatusType.RecoveryApproved;
await _emergencyAccessRepository.ReplaceAsync(ea);
var grantorNameOrEmail = string.IsNullOrWhiteSpace(details.GrantorName) ? details.GrantorEmail : details.GrantorName;
var granteeNameOrEmail = string.IsNullOrWhiteSpace(details.GranteeName) ? details.GranteeEmail : details.GranteeName;
await _mailService.SendEmergencyAccessRecoveryApproved(ea, grantorNameOrEmail, details.GranteeEmail);
await _mailService.SendEmergencyAccessRecoveryTimedOut(ea, granteeNameOrEmail, details.GrantorEmail);
}
}
public async Task<EmergencyAccessViewData> ViewAsync(Guid id, User requestingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View))
{
throw new BadRequestException("Emergency Access not valid.");
await _mailService.SendEmergencyAccessRecoveryReminder(ea, granteeNameOrEmail, notify.GrantorEmail);
}
}
var ciphers = await _cipherRepository.GetManyByUserIdAsync(emergencyAccess.GrantorId, false);
return new EmergencyAccessViewData
public async Task HandleTimedOutRequestsAsync()
{
EmergencyAccess = emergencyAccess,
Ciphers = ciphers,
};
}
var expired = await _emergencyAccessRepository.GetExpiredRecoveriesAsync();
public async Task<AttachmentResponseData> GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User requestingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
foreach (var details in expired)
{
var ea = details.ToEmergencyAccess();
ea.Status = EmergencyAccessStatusType.RecoveryApproved;
await _emergencyAccessRepository.ReplaceAsync(ea);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View))
{
throw new BadRequestException("Emergency Access not valid.");
var grantorNameOrEmail = string.IsNullOrWhiteSpace(details.GrantorName) ? details.GrantorEmail : details.GrantorName;
var granteeNameOrEmail = string.IsNullOrWhiteSpace(details.GranteeName) ? details.GranteeEmail : details.GranteeName;
await _mailService.SendEmergencyAccessRecoveryApproved(ea, grantorNameOrEmail, details.GranteeEmail);
await _mailService.SendEmergencyAccessRecoveryTimedOut(ea, granteeNameOrEmail, details.GrantorEmail);
}
}
var cipher = await _cipherRepository.GetByIdAsync(cipherId, emergencyAccess.GrantorId);
return await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId);
}
public async Task<EmergencyAccessViewData> ViewAsync(Guid id, User requestingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
private async Task SendInviteAsync(EmergencyAccess emergencyAccess, string invitingUsersName)
{
var token = _dataProtectorTokenizer.Protect(new EmergencyAccessInviteTokenable(emergencyAccess, _globalSettings.OrganizationInviteExpirationHours));
await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token);
}
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View))
{
throw new BadRequestException("Emergency Access not valid.");
}
private string NameOrEmail(User user)
{
return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name;
}
var ciphers = await _cipherRepository.GetManyByUserIdAsync(emergencyAccess.GrantorId, false);
private bool IsValidRequest(EmergencyAccess availibleAccess, User requestingUser, EmergencyAccessType requestedAccessType)
{
return availibleAccess != null &&
availibleAccess.GranteeId == requestingUser.Id &&
availibleAccess.Status == EmergencyAccessStatusType.RecoveryApproved &&
availibleAccess.Type == requestedAccessType;
return new EmergencyAccessViewData
{
EmergencyAccess = emergencyAccess,
Ciphers = ciphers,
};
}
public async Task<AttachmentResponseData> GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User requestingUser)
{
var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id);
if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View))
{
throw new BadRequestException("Emergency Access not valid.");
}
var cipher = await _cipherRepository.GetByIdAsync(cipherId, emergencyAccess.GrantorId);
return await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId);
}
private async Task SendInviteAsync(EmergencyAccess emergencyAccess, string invitingUsersName)
{
var token = _dataProtectorTokenizer.Protect(new EmergencyAccessInviteTokenable(emergencyAccess, _globalSettings.OrganizationInviteExpirationHours));
await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token);
}
private string NameOrEmail(User user)
{
return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name;
}
private bool IsValidRequest(EmergencyAccess availibleAccess, User requestingUser, EmergencyAccessType requestedAccessType)
{
return availibleAccess != null &&
availibleAccess.GranteeId == requestingUser.Id &&
availibleAccess.Status == EmergencyAccessStatusType.RecoveryApproved &&
availibleAccess.Type == requestedAccessType;
}
}
}

View File

@ -7,321 +7,322 @@ using Bit.Core.Models.Data.Organizations;
using Bit.Core.Repositories;
using Bit.Core.Settings;
namespace Bit.Core.Services;
public class EventService : IEventService
namespace Bit.Core.Services
{
private readonly IEventWriteService _eventWriteService;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IProviderUserRepository _providerUserRepository;
private readonly IApplicationCacheService _applicationCacheService;
private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings;
public EventService(
IEventWriteService eventWriteService,
IOrganizationUserRepository organizationUserRepository,
IProviderUserRepository providerUserRepository,
IApplicationCacheService applicationCacheService,
ICurrentContext currentContext,
GlobalSettings globalSettings)
public class EventService : IEventService
{
_eventWriteService = eventWriteService;
_organizationUserRepository = organizationUserRepository;
_providerUserRepository = providerUserRepository;
_applicationCacheService = applicationCacheService;
_currentContext = currentContext;
_globalSettings = globalSettings;
}
private readonly IEventWriteService _eventWriteService;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IProviderUserRepository _providerUserRepository;
private readonly IApplicationCacheService _applicationCacheService;
private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings;
public async Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null)
{
var events = new List<IEvent>
public EventService(
IEventWriteService eventWriteService,
IOrganizationUserRepository organizationUserRepository,
IProviderUserRepository providerUserRepository,
IApplicationCacheService applicationCacheService,
ICurrentContext currentContext,
GlobalSettings globalSettings)
{
new EventMessage(_currentContext)
_eventWriteService = eventWriteService;
_organizationUserRepository = organizationUserRepository;
_providerUserRepository = providerUserRepository;
_applicationCacheService = applicationCacheService;
_currentContext = currentContext;
_globalSettings = globalSettings;
}
public async Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null)
{
var events = new List<IEvent>
{
UserId = userId,
ActingUserId = userId,
Type = type,
Date = date.GetValueOrDefault(DateTime.UtcNow)
new EventMessage(_currentContext)
{
UserId = userId,
ActingUserId = userId,
Type = type,
Date = date.GetValueOrDefault(DateTime.UtcNow)
}
};
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, userId);
var orgEvents = orgs.Where(o => CanUseEvents(orgAbilities, o.Id))
.Select(o => new EventMessage(_currentContext)
{
OrganizationId = o.Id,
UserId = userId,
ActingUserId = userId,
Type = type,
Date = DateTime.UtcNow
});
var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync();
var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, userId);
var providerEvents = providers.Where(o => CanUseProviderEvents(providerAbilities, o.Id))
.Select(p => new EventMessage(_currentContext)
{
ProviderId = p.Id,
UserId = userId,
ActingUserId = userId,
Type = type,
Date = DateTime.UtcNow
});
if (orgEvents.Any() || providerEvents.Any())
{
events.AddRange(orgEvents);
events.AddRange(providerEvents);
await _eventWriteService.CreateManyAsync(events);
}
};
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, userId);
var orgEvents = orgs.Where(o => CanUseEvents(orgAbilities, o.Id))
.Select(o => new EventMessage(_currentContext)
else
{
OrganizationId = o.Id,
UserId = userId,
ActingUserId = userId,
Type = type,
Date = DateTime.UtcNow
});
var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync();
var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, userId);
var providerEvents = providers.Where(o => CanUseProviderEvents(providerAbilities, o.Id))
.Select(p => new EventMessage(_currentContext)
{
ProviderId = p.Id,
UserId = userId,
ActingUserId = userId,
Type = type,
Date = DateTime.UtcNow
});
if (orgEvents.Any() || providerEvents.Any())
{
events.AddRange(orgEvents);
events.AddRange(providerEvents);
await _eventWriteService.CreateManyAsync(events);
await _eventWriteService.CreateAsync(events.First());
}
}
else
{
await _eventWriteService.CreateAsync(events.First());
}
}
public async Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null)
{
var e = await BuildCipherEventMessageAsync(cipher, type, date);
if (e != null)
public async Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null)
{
await _eventWriteService.CreateAsync(e);
}
}
public async Task LogCipherEventsAsync(IEnumerable<Tuple<Cipher, EventType, DateTime?>> events)
{
var cipherEvents = new List<IEvent>();
foreach (var ev in events)
{
var e = await BuildCipherEventMessageAsync(ev.Item1, ev.Item2, ev.Item3);
var e = await BuildCipherEventMessageAsync(cipher, type, date);
if (e != null)
{
cipherEvents.Add(e);
await _eventWriteService.CreateAsync(e);
}
}
await _eventWriteService.CreateManyAsync(cipherEvents);
}
private async Task<EventMessage> BuildCipherEventMessageAsync(Cipher cipher, EventType type, DateTime? date = null)
{
// Only logging organization cipher events for now.
if (!cipher.OrganizationId.HasValue || (!_currentContext?.UserId.HasValue ?? true))
public async Task LogCipherEventsAsync(IEnumerable<Tuple<Cipher, EventType, DateTime?>> events)
{
return null;
var cipherEvents = new List<IEvent>();
foreach (var ev in events)
{
var e = await BuildCipherEventMessageAsync(ev.Item1, ev.Item2, ev.Item3);
if (e != null)
{
cipherEvents.Add(e);
}
}
await _eventWriteService.CreateManyAsync(cipherEvents);
}
if (cipher.OrganizationId.HasValue)
private async Task<EventMessage> BuildCipherEventMessageAsync(Cipher cipher, EventType type, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
if (!CanUseEvents(orgAbilities, cipher.OrganizationId.Value))
// Only logging organization cipher events for now.
if (!cipher.OrganizationId.HasValue || (!_currentContext?.UserId.HasValue ?? true))
{
return null;
}
}
return new EventMessage(_currentContext)
{
OrganizationId = cipher.OrganizationId,
UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId,
CipherId = cipher.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
ProviderId = await GetProviderIdAsync(cipher.OrganizationId),
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
}
public async Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
if (!CanUseEvents(orgAbilities, collection.OrganizationId))
{
return;
}
var e = new EventMessage(_currentContext)
{
OrganizationId = collection.OrganizationId,
CollectionId = collection.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
ProviderId = await GetProviderIdAsync(collection.OrganizationId),
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
await _eventWriteService.CreateAsync(e);
}
public async Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
if (!CanUseEvents(orgAbilities, group.OrganizationId))
{
return;
}
var e = new EventMessage(_currentContext)
{
OrganizationId = group.OrganizationId,
GroupId = group.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
ProviderId = await GetProviderIdAsync(@group.OrganizationId),
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
await _eventWriteService.CreateAsync(e);
}
public async Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
if (!CanUseEvents(orgAbilities, policy.OrganizationId))
{
return;
}
var e = new EventMessage(_currentContext)
{
OrganizationId = policy.OrganizationId,
PolicyId = policy.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
ProviderId = await GetProviderIdAsync(policy.OrganizationId),
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
await _eventWriteService.CreateAsync(e);
}
public async Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type,
DateTime? date = null) =>
await LogOrganizationUserEventsAsync(new[] { (organizationUser, type, date) });
public async Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var eventMessages = new List<IEvent>();
foreach (var (organizationUser, type, date) in events)
{
if (!CanUseEvents(orgAbilities, organizationUser.OrganizationId))
if (cipher.OrganizationId.HasValue)
{
continue;
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
if (!CanUseEvents(orgAbilities, cipher.OrganizationId.Value))
{
return null;
}
}
eventMessages.Add(new EventMessage(_currentContext)
return new EventMessage(_currentContext)
{
OrganizationId = organizationUser.OrganizationId,
UserId = organizationUser.UserId,
OrganizationUserId = organizationUser.Id,
ProviderId = await GetProviderIdAsync(organizationUser.OrganizationId),
OrganizationId = cipher.OrganizationId,
UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId,
CipherId = cipher.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
ProviderId = await GetProviderIdAsync(cipher.OrganizationId),
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
}
public async Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
if (!CanUseEvents(orgAbilities, collection.OrganizationId))
{
return;
}
var e = new EventMessage(_currentContext)
{
OrganizationId = collection.OrganizationId,
CollectionId = collection.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
ProviderId = await GetProviderIdAsync(collection.OrganizationId),
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
await _eventWriteService.CreateAsync(e);
}
public async Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
if (!CanUseEvents(orgAbilities, group.OrganizationId))
{
return;
}
var e = new EventMessage(_currentContext)
{
OrganizationId = group.OrganizationId,
GroupId = group.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
ProviderId = await GetProviderIdAsync(@group.OrganizationId),
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
await _eventWriteService.CreateAsync(e);
}
public async Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
if (!CanUseEvents(orgAbilities, policy.OrganizationId))
{
return;
}
var e = new EventMessage(_currentContext)
{
OrganizationId = policy.OrganizationId,
PolicyId = policy.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
ProviderId = await GetProviderIdAsync(policy.OrganizationId),
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
await _eventWriteService.CreateAsync(e);
}
public async Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type,
DateTime? date = null) =>
await LogOrganizationUserEventsAsync(new[] { (organizationUser, type, date) });
public async Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events)
{
var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync();
var eventMessages = new List<IEvent>();
foreach (var (organizationUser, type, date) in events)
{
if (!CanUseEvents(orgAbilities, organizationUser.OrganizationId))
{
continue;
}
eventMessages.Add(new EventMessage(_currentContext)
{
OrganizationId = organizationUser.OrganizationId,
UserId = organizationUser.UserId,
OrganizationUserId = organizationUser.Id,
ProviderId = await GetProviderIdAsync(organizationUser.OrganizationId),
Type = type,
ActingUserId = _currentContext?.UserId,
Date = date.GetValueOrDefault(DateTime.UtcNow)
});
}
await _eventWriteService.CreateManyAsync(eventMessages);
}
public async Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null)
{
if (!organization.Enabled || !organization.UseEvents)
{
return;
}
var e = new EventMessage(_currentContext)
{
OrganizationId = organization.Id,
ProviderId = await GetProviderIdAsync(organization.Id),
Type = type,
ActingUserId = _currentContext?.UserId,
Date = date.GetValueOrDefault(DateTime.UtcNow),
InstallationId = GetInstallationId(),
};
await _eventWriteService.CreateAsync(e);
}
public async Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null)
{
await LogProviderUsersEventAsync(new[] { (providerUser, type, date) });
}
public async Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events)
{
var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync();
var eventMessages = new List<IEvent>();
foreach (var (providerUser, type, date) in events)
{
if (!CanUseProviderEvents(providerAbilities, providerUser.ProviderId))
{
continue;
}
eventMessages.Add(new EventMessage(_currentContext)
{
ProviderId = providerUser.ProviderId,
UserId = providerUser.UserId,
ProviderUserId = providerUser.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
Date = date.GetValueOrDefault(DateTime.UtcNow)
});
}
await _eventWriteService.CreateManyAsync(eventMessages);
}
public async Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type,
DateTime? date = null)
{
var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync();
if (!CanUseProviderEvents(providerAbilities, providerOrganization.ProviderId))
{
return;
}
var e = new EventMessage(_currentContext)
{
ProviderId = providerOrganization.ProviderId,
ProviderOrganizationId = providerOrganization.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
Date = date.GetValueOrDefault(DateTime.UtcNow)
});
};
await _eventWriteService.CreateAsync(e);
}
await _eventWriteService.CreateManyAsync(eventMessages);
}
public async Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null)
{
if (!organization.Enabled || !organization.UseEvents)
private async Task<Guid?> GetProviderIdAsync(Guid? orgId)
{
return;
}
var e = new EventMessage(_currentContext)
{
OrganizationId = organization.Id,
ProviderId = await GetProviderIdAsync(organization.Id),
Type = type,
ActingUserId = _currentContext?.UserId,
Date = date.GetValueOrDefault(DateTime.UtcNow),
InstallationId = GetInstallationId(),
};
await _eventWriteService.CreateAsync(e);
}
public async Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null)
{
await LogProviderUsersEventAsync(new[] { (providerUser, type, date) });
}
public async Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events)
{
var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync();
var eventMessages = new List<IEvent>();
foreach (var (providerUser, type, date) in events)
{
if (!CanUseProviderEvents(providerAbilities, providerUser.ProviderId))
if (_currentContext == null || !orgId.HasValue)
{
continue;
return null;
}
eventMessages.Add(new EventMessage(_currentContext)
return await _currentContext.ProviderIdForOrg(orgId.Value);
}
private Guid? GetInstallationId()
{
if (_currentContext == null)
{
ProviderId = providerUser.ProviderId,
UserId = providerUser.UserId,
ProviderUserId = providerUser.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
Date = date.GetValueOrDefault(DateTime.UtcNow)
});
return null;
}
return _currentContext.InstallationId;
}
await _eventWriteService.CreateManyAsync(eventMessages);
}
public async Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type,
DateTime? date = null)
{
var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync();
if (!CanUseProviderEvents(providerAbilities, providerOrganization.ProviderId))
private bool CanUseEvents(IDictionary<Guid, OrganizationAbility> orgAbilities, Guid orgId)
{
return;
return orgAbilities != null && orgAbilities.ContainsKey(orgId) &&
orgAbilities[orgId].Enabled && orgAbilities[orgId].UseEvents;
}
var e = new EventMessage(_currentContext)
private bool CanUseProviderEvents(IDictionary<Guid, ProviderAbility> providerAbilities, Guid providerId)
{
ProviderId = providerOrganization.ProviderId,
ProviderOrganizationId = providerOrganization.Id,
Type = type,
ActingUserId = _currentContext?.UserId,
Date = date.GetValueOrDefault(DateTime.UtcNow)
};
await _eventWriteService.CreateAsync(e);
}
private async Task<Guid?> GetProviderIdAsync(Guid? orgId)
{
if (_currentContext == null || !orgId.HasValue)
{
return null;
return providerAbilities != null && providerAbilities.ContainsKey(providerId) &&
providerAbilities[providerId].Enabled && providerAbilities[providerId].UseEvents;
}
return await _currentContext.ProviderIdForOrg(orgId.Value);
}
private Guid? GetInstallationId()
{
if (_currentContext == null)
{
return null;
}
return _currentContext.InstallationId;
}
private bool CanUseEvents(IDictionary<Guid, OrganizationAbility> orgAbilities, Guid orgId)
{
return orgAbilities != null && orgAbilities.ContainsKey(orgId) &&
orgAbilities[orgId].Enabled && orgAbilities[orgId].UseEvents;
}
private bool CanUseProviderEvents(IDictionary<Guid, ProviderAbility> providerAbilities, Guid providerId)
{
return providerAbilities != null && providerAbilities.ContainsKey(providerId) &&
providerAbilities[providerId].Enabled && providerAbilities[providerId].UseEvents;
}
}

View File

@ -5,81 +5,82 @@ using Bit.Core.Models.Business;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
namespace Bit.Core.Services;
public class GroupService : IGroupService
namespace Bit.Core.Services
{
private readonly IEventService _eventService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IGroupRepository _groupRepository;
private readonly IReferenceEventService _referenceEventService;
public GroupService(
IEventService eventService,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IGroupRepository groupRepository,
IReferenceEventService referenceEventService)
public class GroupService : IGroupService
{
_eventService = eventService;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_groupRepository = groupRepository;
_referenceEventService = referenceEventService;
}
private readonly IEventService _eventService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IGroupRepository _groupRepository;
private readonly IReferenceEventService _referenceEventService;
public async Task SaveAsync(Group group, IEnumerable<SelectionReadOnly> collections = null)
{
var org = await _organizationRepository.GetByIdAsync(group.OrganizationId);
if (org == null)
public GroupService(
IEventService eventService,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IGroupRepository groupRepository,
IReferenceEventService referenceEventService)
{
throw new BadRequestException("Organization not found");
_eventService = eventService;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_groupRepository = groupRepository;
_referenceEventService = referenceEventService;
}
if (!org.UseGroups)
public async Task SaveAsync(Group group, IEnumerable<SelectionReadOnly> collections = null)
{
throw new BadRequestException("This organization cannot use groups.");
}
if (group.Id == default(Guid))
{
group.CreationDate = group.RevisionDate = DateTime.UtcNow;
if (collections == null)
var org = await _organizationRepository.GetByIdAsync(group.OrganizationId);
if (org == null)
{
await _groupRepository.CreateAsync(group);
throw new BadRequestException("Organization not found");
}
if (!org.UseGroups)
{
throw new BadRequestException("This organization cannot use groups.");
}
if (group.Id == default(Guid))
{
group.CreationDate = group.RevisionDate = DateTime.UtcNow;
if (collections == null)
{
await _groupRepository.CreateAsync(group);
}
else
{
await _groupRepository.CreateAsync(group, collections);
}
await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Created);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.GroupCreated, org));
}
else
{
await _groupRepository.CreateAsync(group, collections);
group.RevisionDate = DateTime.UtcNow;
await _groupRepository.ReplaceAsync(group, collections ?? new List<SelectionReadOnly>());
await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Updated);
}
await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Created);
await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.GroupCreated, org));
}
else
public async Task DeleteAsync(Group group)
{
group.RevisionDate = DateTime.UtcNow;
await _groupRepository.ReplaceAsync(group, collections ?? new List<SelectionReadOnly>());
await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Updated);
await _groupRepository.DeleteAsync(group);
await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Deleted);
}
}
public async Task DeleteAsync(Group group)
{
await _groupRepository.DeleteAsync(group);
await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Deleted);
}
public async Task DeleteUserAsync(Group group, Guid organizationUserId)
{
var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId);
if (orgUser == null || orgUser.OrganizationId != group.OrganizationId)
public async Task DeleteUserAsync(Group group, Guid organizationUserId)
{
throw new NotFoundException();
var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId);
if (orgUser == null || orgUser.OrganizationId != group.OrganizationId)
{
throw new NotFoundException();
}
await _groupRepository.DeleteUserAsync(group.Id, organizationUserId);
await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_UpdatedGroups);
}
await _groupRepository.DeleteUserAsync(group.Id, organizationUserId);
await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_UpdatedGroups);
}
}

View File

@ -8,124 +8,125 @@ using Bit.Core.Settings;
using Bit.Core.Tokens;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class HCaptchaValidationService : ICaptchaValidationService
namespace Bit.Core.Services
{
private readonly ILogger<HCaptchaValidationService> _logger;
private readonly IHttpClientFactory _httpClientFactory;
private readonly GlobalSettings _globalSettings;
private readonly IDataProtectorTokenFactory<HCaptchaTokenable> _tokenizer;
public HCaptchaValidationService(
ILogger<HCaptchaValidationService> logger,
IHttpClientFactory httpClientFactory,
IDataProtectorTokenFactory<HCaptchaTokenable> tokenizer,
GlobalSettings globalSettings)
public class HCaptchaValidationService : ICaptchaValidationService
{
_logger = logger;
_httpClientFactory = httpClientFactory;
_globalSettings = globalSettings;
_tokenizer = tokenizer;
}
private readonly ILogger<HCaptchaValidationService> _logger;
private readonly IHttpClientFactory _httpClientFactory;
private readonly GlobalSettings _globalSettings;
private readonly IDataProtectorTokenFactory<HCaptchaTokenable> _tokenizer;
public string SiteKeyResponseKeyName => "HCaptcha_SiteKey";
public string SiteKey => _globalSettings.Captcha.HCaptchaSiteKey;
public string GenerateCaptchaBypassToken(User user) => _tokenizer.Protect(new HCaptchaTokenable(user));
public async Task<CaptchaResponse> ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress,
User user = null)
{
var response = new CaptchaResponse { Success = false };
if (string.IsNullOrWhiteSpace(captchaResponse))
public HCaptchaValidationService(
ILogger<HCaptchaValidationService> logger,
IHttpClientFactory httpClientFactory,
IDataProtectorTokenFactory<HCaptchaTokenable> tokenizer,
GlobalSettings globalSettings)
{
return response;
_logger = logger;
_httpClientFactory = httpClientFactory;
_globalSettings = globalSettings;
_tokenizer = tokenizer;
}
if (user != null && ValidateCaptchaBypassToken(captchaResponse, user))
{
response.Success = true;
return response;
}
public string SiteKeyResponseKeyName => "HCaptcha_SiteKey";
public string SiteKey => _globalSettings.Captcha.HCaptchaSiteKey;
var httpClient = _httpClientFactory.CreateClient("HCaptchaValidationService");
public string GenerateCaptchaBypassToken(User user) => _tokenizer.Protect(new HCaptchaTokenable(user));
var requestMessage = new HttpRequestMessage
public async Task<CaptchaResponse> ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress,
User user = null)
{
Method = HttpMethod.Post,
RequestUri = new Uri("https://hcaptcha.com/siteverify"),
Content = new FormUrlEncodedContent(new Dictionary<string, string>
var response = new CaptchaResponse { Success = false };
if (string.IsNullOrWhiteSpace(captchaResponse))
{
{ "response", captchaResponse.TrimStart("hcaptcha|".ToCharArray()) },
{ "secret", _globalSettings.Captcha.HCaptchaSecretKey },
{ "sitekey", SiteKey },
{ "remoteip", clientIpAddress }
})
};
return response;
}
HttpResponseMessage responseMessage;
try
{
responseMessage = await httpClient.SendAsync(requestMessage);
}
catch (Exception e)
{
_logger.LogError(11389, e, "Unable to verify with HCaptcha.");
if (user != null && ValidateCaptchaBypassToken(captchaResponse, user))
{
response.Success = true;
return response;
}
var httpClient = _httpClientFactory.CreateClient("HCaptchaValidationService");
var requestMessage = new HttpRequestMessage
{
Method = HttpMethod.Post,
RequestUri = new Uri("https://hcaptcha.com/siteverify"),
Content = new FormUrlEncodedContent(new Dictionary<string, string>
{
{ "response", captchaResponse.TrimStart("hcaptcha|".ToCharArray()) },
{ "secret", _globalSettings.Captcha.HCaptchaSecretKey },
{ "sitekey", SiteKey },
{ "remoteip", clientIpAddress }
})
};
HttpResponseMessage responseMessage;
try
{
responseMessage = await httpClient.SendAsync(requestMessage);
}
catch (Exception e)
{
_logger.LogError(11389, e, "Unable to verify with HCaptcha.");
return response;
}
if (!responseMessage.IsSuccessStatusCode)
{
return response;
}
using var hcaptchaResponse = await responseMessage.Content.ReadFromJsonAsync<HCaptchaResponse>();
response.Success = hcaptchaResponse.Success;
var score = hcaptchaResponse.Score.GetValueOrDefault();
response.MaybeBot = score >= _globalSettings.Captcha.MaybeBotScoreThreshold;
response.IsBot = score >= _globalSettings.Captcha.IsBotScoreThreshold;
response.Score = score;
return response;
}
if (!responseMessage.IsSuccessStatusCode)
public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null)
{
return response;
if (user == null)
{
return currentContext.IsBot || _globalSettings.Captcha.ForceCaptchaRequired;
}
var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts;
var failedLoginCount = user?.FailedLoginCount ?? 0;
var cloudEmailUnverified = !_globalSettings.SelfHosted && !user.EmailVerified;
return currentContext.IsBot ||
_globalSettings.Captcha.ForceCaptchaRequired ||
cloudEmailUnverified ||
failedLoginCeiling > 0 && failedLoginCount >= failedLoginCeiling;
}
using var hcaptchaResponse = await responseMessage.Content.ReadFromJsonAsync<HCaptchaResponse>();
response.Success = hcaptchaResponse.Success;
var score = hcaptchaResponse.Score.GetValueOrDefault();
response.MaybeBot = score >= _globalSettings.Captcha.MaybeBotScoreThreshold;
response.IsBot = score >= _globalSettings.Captcha.IsBotScoreThreshold;
response.Score = score;
return response;
}
private static bool TokenIsValidApiKey(string bypassToken, User user) =>
!string.IsNullOrWhiteSpace(bypassToken) && user != null && user.ApiKey == bypassToken;
public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null)
{
if (user == null)
private bool TokenIsValidCaptchaBypassToken(string encryptedToken, User user)
{
return currentContext.IsBot || _globalSettings.Captcha.ForceCaptchaRequired;
return _tokenizer.TryUnprotect(encryptedToken, out var data) &&
data.Valid && data.TokenIsValid(user);
}
var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts;
var failedLoginCount = user?.FailedLoginCount ?? 0;
var cloudEmailUnverified = !_globalSettings.SelfHosted && !user.EmailVerified;
return currentContext.IsBot ||
_globalSettings.Captcha.ForceCaptchaRequired ||
cloudEmailUnverified ||
failedLoginCeiling > 0 && failedLoginCount >= failedLoginCeiling;
}
private bool ValidateCaptchaBypassToken(string bypassToken, User user) =>
TokenIsValidApiKey(bypassToken, user) || TokenIsValidCaptchaBypassToken(bypassToken, user);
private static bool TokenIsValidApiKey(string bypassToken, User user) =>
!string.IsNullOrWhiteSpace(bypassToken) && user != null && user.ApiKey == bypassToken;
public class HCaptchaResponse : IDisposable
{
[JsonPropertyName("success")]
public bool Success { get; set; }
[JsonPropertyName("score")]
public double? Score { get; set; }
[JsonPropertyName("score_reason")]
public List<string> ScoreReason { get; set; }
private bool TokenIsValidCaptchaBypassToken(string encryptedToken, User user)
{
return _tokenizer.TryUnprotect(encryptedToken, out var data) &&
data.Valid && data.TokenIsValid(user);
}
private bool ValidateCaptchaBypassToken(string bypassToken, User user) =>
TokenIsValidApiKey(bypassToken, user) || TokenIsValidCaptchaBypassToken(bypassToken, user);
public class HCaptchaResponse : IDisposable
{
[JsonPropertyName("success")]
public bool Success { get; set; }
[JsonPropertyName("score")]
public double? Score { get; set; }
[JsonPropertyName("score_reason")]
public List<string> ScoreReason { get; set; }
public void Dispose() { }
public void Dispose() { }
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -2,35 +2,36 @@
using Bit.Core.Resources;
using Microsoft.Extensions.Localization;
namespace Bit.Core.Services;
public class I18nService : II18nService
namespace Bit.Core.Services
{
private readonly IStringLocalizer _localizer;
public I18nService(IStringLocalizerFactory factory)
public class I18nService : II18nService
{
var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName);
_localizer = factory.Create("SharedResources", assemblyName.Name);
}
private readonly IStringLocalizer _localizer;
public LocalizedString GetLocalizedHtmlString(string key)
{
return _localizer[key];
}
public I18nService(IStringLocalizerFactory factory)
{
var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName);
_localizer = factory.Create("SharedResources", assemblyName.Name);
}
public LocalizedString GetLocalizedHtmlString(string key, params object[] args)
{
return _localizer[key, args];
}
public LocalizedString GetLocalizedHtmlString(string key)
{
return _localizer[key];
}
public string Translate(string key, params object[] args)
{
return string.Format(GetLocalizedHtmlString(key).ToString(), args);
}
public LocalizedString GetLocalizedHtmlString(string key, params object[] args)
{
return _localizer[key, args];
}
public string T(string key, params object[] args)
{
return Translate(key, args);
public string Translate(string key, params object[] args)
{
return string.Format(GetLocalizedHtmlString(key).ToString(), args);
}
public string T(string key, params object[] args)
{
return Translate(key, args);
}
}
}

View File

@ -3,28 +3,29 @@ using Bit.Core.Resources;
using Microsoft.AspNetCore.Mvc.Localization;
using Microsoft.Extensions.Localization;
namespace Bit.Core.Services;
public class I18nViewLocalizer : IViewLocalizer
namespace Bit.Core.Services
{
private readonly IStringLocalizer _stringLocalizer;
private readonly IHtmlLocalizer _htmlLocalizer;
public I18nViewLocalizer(IStringLocalizerFactory stringFactory,
IHtmlLocalizerFactory htmlFactory)
public class I18nViewLocalizer : IViewLocalizer
{
var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName);
_stringLocalizer = stringFactory.Create("SharedResources", assemblyName.Name);
_htmlLocalizer = htmlFactory.Create("SharedResources", assemblyName.Name);
private readonly IStringLocalizer _stringLocalizer;
private readonly IHtmlLocalizer _htmlLocalizer;
public I18nViewLocalizer(IStringLocalizerFactory stringFactory,
IHtmlLocalizerFactory htmlFactory)
{
var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName);
_stringLocalizer = stringFactory.Create("SharedResources", assemblyName.Name);
_htmlLocalizer = htmlFactory.Create("SharedResources", assemblyName.Name);
}
public LocalizedHtmlString this[string name] => _htmlLocalizer[name];
public LocalizedHtmlString this[string name, params object[] args] => _htmlLocalizer[name, args];
public IEnumerable<LocalizedString> GetAllStrings(bool includeParentCultures) =>
_stringLocalizer.GetAllStrings(includeParentCultures);
public LocalizedString GetString(string name) => _stringLocalizer[name];
public LocalizedString GetString(string name, params object[] arguments) =>
_stringLocalizer[name, arguments];
}
public LocalizedHtmlString this[string name] => _htmlLocalizer[name];
public LocalizedHtmlString this[string name, params object[] args] => _htmlLocalizer[name, args];
public IEnumerable<LocalizedString> GetAllStrings(bool includeParentCultures) =>
_stringLocalizer.GetAllStrings(includeParentCultures);
public LocalizedString GetString(string name) => _stringLocalizer[name];
public LocalizedString GetString(string name, params object[] arguments) =>
_stringLocalizer[name, arguments];
}

View File

@ -4,96 +4,97 @@ using Bit.Core.Models.Data;
using Bit.Core.Models.Data.Organizations;
using Bit.Core.Repositories;
namespace Bit.Core.Services;
public class InMemoryApplicationCacheService : IApplicationCacheService
namespace Bit.Core.Services
{
private readonly IOrganizationRepository _organizationRepository;
private readonly IProviderRepository _providerRepository;
private DateTime _lastOrgAbilityRefresh = DateTime.MinValue;
private IDictionary<Guid, OrganizationAbility> _orgAbilities;
private TimeSpan _orgAbilitiesRefreshInterval = TimeSpan.FromMinutes(10);
private IDictionary<Guid, ProviderAbility> _providerAbilities;
public InMemoryApplicationCacheService(
IOrganizationRepository organizationRepository, IProviderRepository providerRepository)
public class InMemoryApplicationCacheService : IApplicationCacheService
{
_organizationRepository = organizationRepository;
_providerRepository = providerRepository;
}
private readonly IOrganizationRepository _organizationRepository;
private readonly IProviderRepository _providerRepository;
private DateTime _lastOrgAbilityRefresh = DateTime.MinValue;
private IDictionary<Guid, OrganizationAbility> _orgAbilities;
private TimeSpan _orgAbilitiesRefreshInterval = TimeSpan.FromMinutes(10);
public virtual async Task<IDictionary<Guid, OrganizationAbility>> GetOrganizationAbilitiesAsync()
{
await InitOrganizationAbilitiesAsync();
return _orgAbilities;
}
private IDictionary<Guid, ProviderAbility> _providerAbilities;
public virtual async Task<IDictionary<Guid, ProviderAbility>> GetProviderAbilitiesAsync()
{
await InitProviderAbilitiesAsync();
return _providerAbilities;
}
public virtual async Task UpsertProviderAbilityAsync(Provider provider)
{
await InitProviderAbilitiesAsync();
var newAbility = new ProviderAbility(provider);
if (_providerAbilities.ContainsKey(provider.Id))
public InMemoryApplicationCacheService(
IOrganizationRepository organizationRepository, IProviderRepository providerRepository)
{
_providerAbilities[provider.Id] = newAbility;
}
else
{
_providerAbilities.Add(provider.Id, newAbility);
}
}
public virtual async Task UpsertOrganizationAbilityAsync(Organization organization)
{
await InitOrganizationAbilitiesAsync();
var newAbility = new OrganizationAbility(organization);
if (_orgAbilities.ContainsKey(organization.Id))
{
_orgAbilities[organization.Id] = newAbility;
}
else
{
_orgAbilities.Add(organization.Id, newAbility);
}
}
public virtual Task DeleteOrganizationAbilityAsync(Guid organizationId)
{
if (_orgAbilities != null && _orgAbilities.ContainsKey(organizationId))
{
_orgAbilities.Remove(organizationId);
_organizationRepository = organizationRepository;
_providerRepository = providerRepository;
}
return Task.FromResult(0);
}
private async Task InitOrganizationAbilitiesAsync()
{
var now = DateTime.UtcNow;
if (_orgAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval)
public virtual async Task<IDictionary<Guid, OrganizationAbility>> GetOrganizationAbilitiesAsync()
{
var abilities = await _organizationRepository.GetManyAbilitiesAsync();
_orgAbilities = abilities.ToDictionary(a => a.Id);
_lastOrgAbilityRefresh = now;
await InitOrganizationAbilitiesAsync();
return _orgAbilities;
}
}
private async Task InitProviderAbilitiesAsync()
{
var now = DateTime.UtcNow;
if (_providerAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval)
public virtual async Task<IDictionary<Guid, ProviderAbility>> GetProviderAbilitiesAsync()
{
var abilities = await _providerRepository.GetManyAbilitiesAsync();
_providerAbilities = abilities.ToDictionary(a => a.Id);
_lastOrgAbilityRefresh = now;
await InitProviderAbilitiesAsync();
return _providerAbilities;
}
public virtual async Task UpsertProviderAbilityAsync(Provider provider)
{
await InitProviderAbilitiesAsync();
var newAbility = new ProviderAbility(provider);
if (_providerAbilities.ContainsKey(provider.Id))
{
_providerAbilities[provider.Id] = newAbility;
}
else
{
_providerAbilities.Add(provider.Id, newAbility);
}
}
public virtual async Task UpsertOrganizationAbilityAsync(Organization organization)
{
await InitOrganizationAbilitiesAsync();
var newAbility = new OrganizationAbility(organization);
if (_orgAbilities.ContainsKey(organization.Id))
{
_orgAbilities[organization.Id] = newAbility;
}
else
{
_orgAbilities.Add(organization.Id, newAbility);
}
}
public virtual Task DeleteOrganizationAbilityAsync(Guid organizationId)
{
if (_orgAbilities != null && _orgAbilities.ContainsKey(organizationId))
{
_orgAbilities.Remove(organizationId);
}
return Task.FromResult(0);
}
private async Task InitOrganizationAbilitiesAsync()
{
var now = DateTime.UtcNow;
if (_orgAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval)
{
var abilities = await _organizationRepository.GetManyAbilitiesAsync();
_orgAbilities = abilities.ToDictionary(a => a.Id);
_lastOrgAbilityRefresh = now;
}
}
private async Task InitProviderAbilitiesAsync()
{
var now = DateTime.UtcNow;
if (_providerAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval)
{
var abilities = await _providerRepository.GetManyAbilitiesAsync();
_providerAbilities = abilities.ToDictionary(a => a.Id);
_lastOrgAbilityRefresh = now;
}
}
}
}

View File

@ -5,61 +5,62 @@ using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.Azure.ServiceBus;
namespace Bit.Core.Services;
public class InMemoryServiceBusApplicationCacheService : InMemoryApplicationCacheService, IApplicationCacheService
namespace Bit.Core.Services
{
private readonly TopicClient _topicClient;
private readonly string _subName;
public InMemoryServiceBusApplicationCacheService(
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
GlobalSettings globalSettings)
: base(organizationRepository, providerRepository)
public class InMemoryServiceBusApplicationCacheService : InMemoryApplicationCacheService, IApplicationCacheService
{
_subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings);
_topicClient = new TopicClient(globalSettings.ServiceBus.ConnectionString,
globalSettings.ServiceBus.ApplicationCacheTopicName);
}
private readonly TopicClient _topicClient;
private readonly string _subName;
public override async Task UpsertOrganizationAbilityAsync(Organization organization)
{
await base.UpsertOrganizationAbilityAsync(organization);
var message = new Message
public InMemoryServiceBusApplicationCacheService(
IOrganizationRepository organizationRepository,
IProviderRepository providerRepository,
GlobalSettings globalSettings)
: base(organizationRepository, providerRepository)
{
Label = _subName,
UserProperties =
{
{ "type", (byte)ApplicationCacheMessageType.UpsertOrganizationAbility },
{ "id", organization.Id },
}
};
var task = _topicClient.SendAsync(message);
}
_subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings);
_topicClient = new TopicClient(globalSettings.ServiceBus.ConnectionString,
globalSettings.ServiceBus.ApplicationCacheTopicName);
}
public override async Task DeleteOrganizationAbilityAsync(Guid organizationId)
{
await base.DeleteOrganizationAbilityAsync(organizationId);
var message = new Message
public override async Task UpsertOrganizationAbilityAsync(Organization organization)
{
Label = _subName,
UserProperties =
await base.UpsertOrganizationAbilityAsync(organization);
var message = new Message
{
{ "type", (byte)ApplicationCacheMessageType.DeleteOrganizationAbility },
{ "id", organizationId },
}
};
var task = _topicClient.SendAsync(message);
}
Label = _subName,
UserProperties =
{
{ "type", (byte)ApplicationCacheMessageType.UpsertOrganizationAbility },
{ "id", organization.Id },
}
};
var task = _topicClient.SendAsync(message);
}
public async Task BaseUpsertOrganizationAbilityAsync(Organization organization)
{
await base.UpsertOrganizationAbilityAsync(organization);
}
public override async Task DeleteOrganizationAbilityAsync(Guid organizationId)
{
await base.DeleteOrganizationAbilityAsync(organizationId);
var message = new Message
{
Label = _subName,
UserProperties =
{
{ "type", (byte)ApplicationCacheMessageType.DeleteOrganizationAbility },
{ "id", organizationId },
}
};
var task = _topicClient.SendAsync(message);
}
public async Task BaseDeleteOrganizationAbilityAsync(Guid organizationId)
{
await base.DeleteOrganizationAbilityAsync(organizationId);
public async Task BaseUpsertOrganizationAbilityAsync(Organization organization)
{
await base.UpsertOrganizationAbilityAsync(organization);
}
public async Task BaseDeleteOrganizationAbilityAsync(Guid organizationId)
{
await base.DeleteOrganizationAbilityAsync(organizationId);
}
}
}

View File

@ -10,251 +10,252 @@ using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class LicensingService : ILicensingService
namespace Bit.Core.Services
{
private readonly X509Certificate2 _certificate;
private readonly IGlobalSettings _globalSettings;
private readonly IUserRepository _userRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IMailService _mailService;
private readonly ILogger<LicensingService> _logger;
private IDictionary<Guid, DateTime> _userCheckCache = new Dictionary<Guid, DateTime>();
public LicensingService(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IMailService mailService,
IWebHostEnvironment environment,
ILogger<LicensingService> logger,
IGlobalSettings globalSettings)
public class LicensingService : ILicensingService
{
_userRepository = userRepository;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_mailService = mailService;
_logger = logger;
_globalSettings = globalSettings;
private readonly X509Certificate2 _certificate;
private readonly IGlobalSettings _globalSettings;
private readonly IUserRepository _userRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IMailService _mailService;
private readonly ILogger<LicensingService> _logger;
var certThumbprint = environment.IsDevelopment() ?
"207E64A231E8AA32AAF68A61037C075EBEBD553F" :
"B34876439FCDA2846505B2EFBBA6C4A951313EBE";
if (_globalSettings.SelfHosted)
private IDictionary<Guid, DateTime> _userCheckCache = new Dictionary<Guid, DateTime>();
public LicensingService(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IMailService mailService,
IWebHostEnvironment environment,
ILogger<LicensingService> logger,
IGlobalSettings globalSettings)
{
_certificate = CoreHelpers.GetEmbeddedCertificateAsync(environment.IsDevelopment() ? "licensing_dev.cer" : "licensing.cer", null)
.GetAwaiter().GetResult();
}
else if (CoreHelpers.SettingHasValue(_globalSettings.Storage?.ConnectionString) &&
CoreHelpers.SettingHasValue(_globalSettings.LicenseCertificatePassword))
{
_certificate = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates",
"licensing.pfx", _globalSettings.LicenseCertificatePassword)
.GetAwaiter().GetResult();
}
else
{
_certificate = CoreHelpers.GetCertificate(certThumbprint);
_userRepository = userRepository;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_mailService = mailService;
_logger = logger;
_globalSettings = globalSettings;
var certThumbprint = environment.IsDevelopment() ?
"207E64A231E8AA32AAF68A61037C075EBEBD553F" :
"B34876439FCDA2846505B2EFBBA6C4A951313EBE";
if (_globalSettings.SelfHosted)
{
_certificate = CoreHelpers.GetEmbeddedCertificateAsync(environment.IsDevelopment() ? "licensing_dev.cer" : "licensing.cer", null)
.GetAwaiter().GetResult();
}
else if (CoreHelpers.SettingHasValue(_globalSettings.Storage?.ConnectionString) &&
CoreHelpers.SettingHasValue(_globalSettings.LicenseCertificatePassword))
{
_certificate = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates",
"licensing.pfx", _globalSettings.LicenseCertificatePassword)
.GetAwaiter().GetResult();
}
else
{
_certificate = CoreHelpers.GetCertificate(certThumbprint);
}
if (_certificate == null || !_certificate.Thumbprint.Equals(CoreHelpers.CleanCertificateThumbprint(certThumbprint),
StringComparison.InvariantCultureIgnoreCase))
{
throw new Exception("Invalid licensing certificate.");
}
if (_globalSettings.SelfHosted && !CoreHelpers.SettingHasValue(_globalSettings.LicenseDirectory))
{
throw new InvalidOperationException("No license directory.");
}
}
if (_certificate == null || !_certificate.Thumbprint.Equals(CoreHelpers.CleanCertificateThumbprint(certThumbprint),
StringComparison.InvariantCultureIgnoreCase))
public async Task ValidateOrganizationsAsync()
{
throw new Exception("Invalid licensing certificate.");
if (!_globalSettings.SelfHosted)
{
return;
}
var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"Validating licenses for {0} organizations.", enabledOrgs.Count);
foreach (var org in enabledOrgs)
{
var license = await ReadOrganizationLicenseAsync(org);
if (license == null)
{
await DisableOrganizationAsync(org, null, "No license file.");
continue;
}
var totalLicensedOrgs = enabledOrgs.Count(o => o.LicenseKey.Equals(license.LicenseKey));
if (totalLicensedOrgs > 1)
{
await DisableOrganizationAsync(org, license, "Multiple organizations.");
continue;
}
if (!license.VerifyData(org, _globalSettings))
{
await DisableOrganizationAsync(org, license, "Invalid data.");
continue;
}
if (!license.VerifySignature(_certificate))
{
await DisableOrganizationAsync(org, license, "Invalid signature.");
continue;
}
}
}
if (_globalSettings.SelfHosted && !CoreHelpers.SettingHasValue(_globalSettings.LicenseDirectory))
private async Task DisableOrganizationAsync(Organization org, ILicense license, string reason)
{
throw new InvalidOperationException("No license directory.");
}
}
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"Organization {0} ({1}) has an invalid license and is being disabled. Reason: {2}",
org.Id, org.Name, reason);
org.Enabled = false;
org.ExpirationDate = license?.Expires ?? DateTime.UtcNow;
org.RevisionDate = DateTime.UtcNow;
await _organizationRepository.ReplaceAsync(org);
public async Task ValidateOrganizationsAsync()
{
if (!_globalSettings.SelfHosted)
{
return;
await _mailService.SendLicenseExpiredAsync(new List<string> { org.BillingEmail }, org.Name);
}
var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync();
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"Validating licenses for {0} organizations.", enabledOrgs.Count);
foreach (var org in enabledOrgs)
public async Task ValidateUsersAsync()
{
var license = await ReadOrganizationLicenseAsync(org);
if (!_globalSettings.SelfHosted)
{
return;
}
var premiumUsers = await _userRepository.GetManyByPremiumAsync(true);
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"Validating premium for {0} users.", premiumUsers.Count);
foreach (var user in premiumUsers)
{
await ProcessUserValidationAsync(user);
}
}
public async Task<bool> ValidateUserPremiumAsync(User user)
{
if (!_globalSettings.SelfHosted)
{
return user.Premium;
}
if (!user.Premium)
{
return false;
}
// Only check once per day
var now = DateTime.UtcNow;
if (_userCheckCache.ContainsKey(user.Id))
{
var lastCheck = _userCheckCache[user.Id];
if (lastCheck < now && now - lastCheck < TimeSpan.FromDays(1))
{
return user.Premium;
}
else
{
_userCheckCache[user.Id] = now;
}
}
else
{
_userCheckCache.Add(user.Id, now);
}
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"Validating premium license for user {0}({1}).", user.Id, user.Email);
return await ProcessUserValidationAsync(user);
}
private async Task<bool> ProcessUserValidationAsync(User user)
{
var license = ReadUserLicense(user);
if (license == null)
{
await DisableOrganizationAsync(org, null, "No license file.");
continue;
await DisablePremiumAsync(user, null, "No license file.");
return false;
}
var totalLicensedOrgs = enabledOrgs.Count(o => o.LicenseKey.Equals(license.LicenseKey));
if (totalLicensedOrgs > 1)
if (!license.VerifyData(user))
{
await DisableOrganizationAsync(org, license, "Multiple organizations.");
continue;
}
if (!license.VerifyData(org, _globalSettings))
{
await DisableOrganizationAsync(org, license, "Invalid data.");
continue;
await DisablePremiumAsync(user, license, "Invalid data.");
return false;
}
if (!license.VerifySignature(_certificate))
{
await DisableOrganizationAsync(org, license, "Invalid signature.");
continue;
await DisablePremiumAsync(user, license, "Invalid signature.");
return false;
}
}
}
private async Task DisableOrganizationAsync(Organization org, ILicense license, string reason)
{
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"Organization {0} ({1}) has an invalid license and is being disabled. Reason: {2}",
org.Id, org.Name, reason);
org.Enabled = false;
org.ExpirationDate = license?.Expires ?? DateTime.UtcNow;
org.RevisionDate = DateTime.UtcNow;
await _organizationRepository.ReplaceAsync(org);
await _mailService.SendLicenseExpiredAsync(new List<string> { org.BillingEmail }, org.Name);
}
public async Task ValidateUsersAsync()
{
if (!_globalSettings.SelfHosted)
{
return;
return true;
}
var premiumUsers = await _userRepository.GetManyByPremiumAsync(true);
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"Validating premium for {0} users.", premiumUsers.Count);
foreach (var user in premiumUsers)
private async Task DisablePremiumAsync(User user, ILicense license, string reason)
{
await ProcessUserValidationAsync(user);
}
}
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"User {0}({1}) has an invalid license and premium is being disabled. Reason: {2}",
user.Id, user.Email, reason);
public async Task<bool> ValidateUserPremiumAsync(User user)
{
if (!_globalSettings.SelfHosted)
{
return user.Premium;
user.Premium = false;
user.PremiumExpirationDate = license?.Expires ?? DateTime.UtcNow;
user.RevisionDate = DateTime.UtcNow;
await _userRepository.ReplaceAsync(user);
await _mailService.SendLicenseExpiredAsync(new List<string> { user.Email });
}
if (!user.Premium)
public bool VerifyLicense(ILicense license)
{
return false;
return license.VerifySignature(_certificate);
}
// Only check once per day
var now = DateTime.UtcNow;
if (_userCheckCache.ContainsKey(user.Id))
public byte[] SignLicense(ILicense license)
{
var lastCheck = _userCheckCache[user.Id];
if (lastCheck < now && now - lastCheck < TimeSpan.FromDays(1))
if (_globalSettings.SelfHosted || !_certificate.HasPrivateKey)
{
return user.Premium;
throw new InvalidOperationException("Cannot sign licenses.");
}
else
return license.Sign(_certificate);
}
private UserLicense ReadUserLicense(User user)
{
var filePath = $"{_globalSettings.LicenseDirectory}/user/{user.Id}.json";
if (!File.Exists(filePath))
{
_userCheckCache[user.Id] = now;
return null;
}
var data = File.ReadAllText(filePath, Encoding.UTF8);
return JsonSerializer.Deserialize<UserLicense>(data);
}
else
public Task<OrganizationLicense> ReadOrganizationLicenseAsync(Organization organization) =>
ReadOrganizationLicenseAsync(organization.Id);
public async Task<OrganizationLicense> ReadOrganizationLicenseAsync(Guid organizationId)
{
_userCheckCache.Add(user.Id, now);
var filePath = Path.Combine(_globalSettings.LicenseDirectory, "organization", $"{organizationId}.json");
if (!File.Exists(filePath))
{
return null;
}
using var fs = File.OpenRead(filePath);
return await JsonSerializer.DeserializeAsync<OrganizationLicense>(fs);
}
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"Validating premium license for user {0}({1}).", user.Id, user.Email);
return await ProcessUserValidationAsync(user);
}
private async Task<bool> ProcessUserValidationAsync(User user)
{
var license = ReadUserLicense(user);
if (license == null)
{
await DisablePremiumAsync(user, null, "No license file.");
return false;
}
if (!license.VerifyData(user))
{
await DisablePremiumAsync(user, license, "Invalid data.");
return false;
}
if (!license.VerifySignature(_certificate))
{
await DisablePremiumAsync(user, license, "Invalid signature.");
return false;
}
return true;
}
private async Task DisablePremiumAsync(User user, ILicense license, string reason)
{
_logger.LogInformation(Constants.BypassFiltersEventId, null,
"User {0}({1}) has an invalid license and premium is being disabled. Reason: {2}",
user.Id, user.Email, reason);
user.Premium = false;
user.PremiumExpirationDate = license?.Expires ?? DateTime.UtcNow;
user.RevisionDate = DateTime.UtcNow;
await _userRepository.ReplaceAsync(user);
await _mailService.SendLicenseExpiredAsync(new List<string> { user.Email });
}
public bool VerifyLicense(ILicense license)
{
return license.VerifySignature(_certificate);
}
public byte[] SignLicense(ILicense license)
{
if (_globalSettings.SelfHosted || !_certificate.HasPrivateKey)
{
throw new InvalidOperationException("Cannot sign licenses.");
}
return license.Sign(_certificate);
}
private UserLicense ReadUserLicense(User user)
{
var filePath = $"{_globalSettings.LicenseDirectory}/user/{user.Id}.json";
if (!File.Exists(filePath))
{
return null;
}
var data = File.ReadAllText(filePath, Encoding.UTF8);
return JsonSerializer.Deserialize<UserLicense>(data);
}
public Task<OrganizationLicense> ReadOrganizationLicenseAsync(Organization organization) =>
ReadOrganizationLicenseAsync(organization.Id);
public async Task<OrganizationLicense> ReadOrganizationLicenseAsync(Guid organizationId)
{
var filePath = Path.Combine(_globalSettings.LicenseDirectory, "organization", $"{organizationId}.json");
if (!File.Exists(filePath))
{
return null;
}
using var fs = File.OpenRead(filePath);
return await JsonSerializer.DeserializeAsync<OrganizationLicense>(fs);
}
}

View File

@ -3,194 +3,195 @@ using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Settings;
namespace Bit.Core.Services;
public class LocalAttachmentStorageService : IAttachmentStorageService
namespace Bit.Core.Services
{
private readonly string _baseAttachmentUrl;
private readonly string _baseDirPath;
private readonly string _baseTempDirPath;
public FileUploadType FileUploadType => FileUploadType.Direct;
public LocalAttachmentStorageService(
IGlobalSettings globalSettings)
public class LocalAttachmentStorageService : IAttachmentStorageService
{
_baseDirPath = globalSettings.Attachment.BaseDirectory;
_baseTempDirPath = $"{_baseDirPath}/temp";
_baseAttachmentUrl = globalSettings.Attachment.BaseUrl;
}
private readonly string _baseAttachmentUrl;
private readonly string _baseDirPath;
private readonly string _baseTempDirPath;
public async Task<string> GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData)
{
await InitAsync();
return $"{_baseAttachmentUrl}/{cipher.Id}/{attachmentData.AttachmentId}";
}
public FileUploadType FileUploadType => FileUploadType.Direct;
public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData)
{
await InitAsync();
var cipherDirPath = CipherDirectoryPath(cipher.Id, temp: false);
CreateDirectoryIfNotExists(cipherDirPath);
using (var fs = File.Create(AttachmentFilePath(cipherDirPath, attachmentData.AttachmentId)))
public LocalAttachmentStorageService(
IGlobalSettings globalSettings)
{
stream.Seek(0, SeekOrigin.Begin);
await stream.CopyToAsync(fs);
}
}
public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData)
{
await InitAsync();
var tempCipherOrgDirPath = OrganizationDirectoryPath(cipherId, organizationId, temp: true);
CreateDirectoryIfNotExists(tempCipherOrgDirPath);
using (var fs = File.Create(AttachmentFilePath(tempCipherOrgDirPath, attachmentData.AttachmentId)))
{
stream.Seek(0, SeekOrigin.Begin);
await stream.CopyToAsync(fs);
}
}
public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData)
{
await InitAsync();
var sourceFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true);
if (!File.Exists(sourceFilePath))
{
return;
_baseDirPath = globalSettings.Attachment.BaseDirectory;
_baseTempDirPath = $"{_baseDirPath}/temp";
_baseAttachmentUrl = globalSettings.Attachment.BaseUrl;
}
var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false);
if (!File.Exists(destFilePath))
public async Task<string> GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData)
{
return;
await InitAsync();
return $"{_baseAttachmentUrl}/{cipher.Id}/{attachmentData.AttachmentId}";
}
var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true);
DeleteFileIfExists(originalFilePath);
File.Move(destFilePath, originalFilePath);
DeleteFileIfExists(destFilePath);
File.Move(sourceFilePath, destFilePath);
}
public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer)
{
await InitAsync();
DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true));
var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true);
if (!File.Exists(originalFilePath))
public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData)
{
return;
await InitAsync();
var cipherDirPath = CipherDirectoryPath(cipher.Id, temp: false);
CreateDirectoryIfNotExists(cipherDirPath);
using (var fs = File.Create(AttachmentFilePath(cipherDirPath, attachmentData.AttachmentId)))
{
stream.Seek(0, SeekOrigin.Begin);
await stream.CopyToAsync(fs);
}
}
var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false);
DeleteFileIfExists(destFilePath);
File.Move(originalFilePath, destFilePath);
DeleteFileIfExists(originalFilePath);
}
public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData)
{
await InitAsync();
DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false));
}
public async Task CleanupAsync(Guid cipherId)
{
await InitAsync();
DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: true));
}
public async Task DeleteAttachmentsForCipherAsync(Guid cipherId)
{
await InitAsync();
DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: false));
}
public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId)
{
await InitAsync();
}
public async Task DeleteAttachmentsForUserAsync(Guid userId)
{
await InitAsync();
}
private void DeleteFileIfExists(string path)
{
if (File.Exists(path))
public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData)
{
File.Delete(path);
}
}
await InitAsync();
var tempCipherOrgDirPath = OrganizationDirectoryPath(cipherId, organizationId, temp: true);
CreateDirectoryIfNotExists(tempCipherOrgDirPath);
private void DeleteDirectoryIfExists(string path)
{
if (Directory.Exists(path))
{
Directory.Delete(path, true);
}
}
private void CreateDirectoryIfNotExists(string path)
{
if (!Directory.Exists(path))
{
Directory.CreateDirectory(path);
}
}
private Task InitAsync()
{
if (!Directory.Exists(_baseDirPath))
{
Directory.CreateDirectory(_baseDirPath);
using (var fs = File.Create(AttachmentFilePath(tempCipherOrgDirPath, attachmentData.AttachmentId)))
{
stream.Seek(0, SeekOrigin.Begin);
await stream.CopyToAsync(fs);
}
}
if (!Directory.Exists(_baseTempDirPath))
public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData)
{
Directory.CreateDirectory(_baseTempDirPath);
await InitAsync();
var sourceFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true);
if (!File.Exists(sourceFilePath))
{
return;
}
var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false);
if (!File.Exists(destFilePath))
{
return;
}
var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true);
DeleteFileIfExists(originalFilePath);
File.Move(destFilePath, originalFilePath);
DeleteFileIfExists(destFilePath);
File.Move(sourceFilePath, destFilePath);
}
return Task.FromResult(0);
}
private string CipherDirectoryPath(Guid cipherId, bool temp = false) =>
Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString());
private string OrganizationDirectoryPath(Guid cipherId, Guid organizationId, bool temp = false) =>
Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString(), organizationId.ToString());
private string AttachmentFilePath(string dir, string attachmentId) => Path.Combine(dir, attachmentId);
private string AttachmentFilePath(string attachmentId, Guid cipherId, Guid? organizationId = null,
bool temp = false) =>
organizationId.HasValue ?
AttachmentFilePath(OrganizationDirectoryPath(cipherId, organizationId.Value, temp), attachmentId) :
AttachmentFilePath(CipherDirectoryPath(cipherId, temp), attachmentId);
public Task<string> GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData)
=> Task.FromResult($"{cipher.Id}/attachment/{attachmentData.AttachmentId}");
public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway)
{
long? length = null;
var path = AttachmentFilePath(attachmentData.AttachmentId, cipher.Id, temp: false);
if (!File.Exists(path))
public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer)
{
return Task.FromResult((false, length));
await InitAsync();
DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true));
var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true);
if (!File.Exists(originalFilePath))
{
return;
}
var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false);
DeleteFileIfExists(destFilePath);
File.Move(originalFilePath, destFilePath);
DeleteFileIfExists(originalFilePath);
}
length = new FileInfo(path).Length;
if (attachmentData.Size < length - leeway || attachmentData.Size > length + leeway)
public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData)
{
return Task.FromResult((false, length));
await InitAsync();
DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false));
}
return Task.FromResult((true, length));
public async Task CleanupAsync(Guid cipherId)
{
await InitAsync();
DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: true));
}
public async Task DeleteAttachmentsForCipherAsync(Guid cipherId)
{
await InitAsync();
DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: false));
}
public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId)
{
await InitAsync();
}
public async Task DeleteAttachmentsForUserAsync(Guid userId)
{
await InitAsync();
}
private void DeleteFileIfExists(string path)
{
if (File.Exists(path))
{
File.Delete(path);
}
}
private void DeleteDirectoryIfExists(string path)
{
if (Directory.Exists(path))
{
Directory.Delete(path, true);
}
}
private void CreateDirectoryIfNotExists(string path)
{
if (!Directory.Exists(path))
{
Directory.CreateDirectory(path);
}
}
private Task InitAsync()
{
if (!Directory.Exists(_baseDirPath))
{
Directory.CreateDirectory(_baseDirPath);
}
if (!Directory.Exists(_baseTempDirPath))
{
Directory.CreateDirectory(_baseTempDirPath);
}
return Task.FromResult(0);
}
private string CipherDirectoryPath(Guid cipherId, bool temp = false) =>
Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString());
private string OrganizationDirectoryPath(Guid cipherId, Guid organizationId, bool temp = false) =>
Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString(), organizationId.ToString());
private string AttachmentFilePath(string dir, string attachmentId) => Path.Combine(dir, attachmentId);
private string AttachmentFilePath(string attachmentId, Guid cipherId, Guid? organizationId = null,
bool temp = false) =>
organizationId.HasValue ?
AttachmentFilePath(OrganizationDirectoryPath(cipherId, organizationId.Value, temp), attachmentId) :
AttachmentFilePath(CipherDirectoryPath(cipherId, temp), attachmentId);
public Task<string> GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData)
=> Task.FromResult($"{cipher.Id}/attachment/{attachmentData.AttachmentId}");
public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway)
{
long? length = null;
var path = AttachmentFilePath(attachmentData.AttachmentId, cipher.Id, temp: false);
if (!File.Exists(path))
{
return Task.FromResult((false, length));
}
length = new FileInfo(path).Length;
if (attachmentData.Size < length - leeway || attachmentData.Size > length + leeway)
{
return Task.FromResult((false, length));
}
return Task.FromResult((true, length));
}
}
}

View File

@ -2,104 +2,105 @@
using Bit.Core.Enums;
using Bit.Core.Settings;
namespace Bit.Core.Services;
public class LocalSendStorageService : ISendFileStorageService
namespace Bit.Core.Services
{
private readonly string _baseDirPath;
private readonly string _baseSendUrl;
private string RelativeFilePath(Send send, string fileID) => $"{send.Id}/{fileID}";
private string FilePath(Send send, string fileID) => $"{_baseDirPath}/{RelativeFilePath(send, fileID)}";
public FileUploadType FileUploadType => FileUploadType.Direct;
public LocalSendStorageService(
GlobalSettings globalSettings)
public class LocalSendStorageService : ISendFileStorageService
{
_baseDirPath = globalSettings.Send.BaseDirectory;
_baseSendUrl = globalSettings.Send.BaseUrl;
}
private readonly string _baseDirPath;
private readonly string _baseSendUrl;
public async Task UploadNewFileAsync(Stream stream, Send send, string fileId)
{
await InitAsync();
var path = FilePath(send, fileId);
Directory.CreateDirectory(Path.GetDirectoryName(path));
using (var fs = File.Create(path))
private string RelativeFilePath(Send send, string fileID) => $"{send.Id}/{fileID}";
private string FilePath(Send send, string fileID) => $"{_baseDirPath}/{RelativeFilePath(send, fileID)}";
public FileUploadType FileUploadType => FileUploadType.Direct;
public LocalSendStorageService(
GlobalSettings globalSettings)
{
stream.Seek(0, SeekOrigin.Begin);
await stream.CopyToAsync(fs);
}
}
public async Task DeleteFileAsync(Send send, string fileId)
{
await InitAsync();
var path = FilePath(send, fileId);
DeleteFileIfExists(path);
DeleteDirectoryIfExistsAndEmpty(Path.GetDirectoryName(path));
}
public async Task DeleteFilesForOrganizationAsync(Guid organizationId)
{
await InitAsync();
}
public async Task DeleteFilesForUserAsync(Guid userId)
{
await InitAsync();
}
public async Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId)
{
await InitAsync();
return $"{_baseSendUrl}/{RelativeFilePath(send, fileId)}";
}
private void DeleteFileIfExists(string path)
{
if (File.Exists(path))
{
File.Delete(path);
}
}
private void DeleteDirectoryIfExistsAndEmpty(string path)
{
if (Directory.Exists(path) && !Directory.EnumerateFiles(path).Any())
{
Directory.Delete(path);
}
}
private Task InitAsync()
{
if (!Directory.Exists(_baseDirPath))
{
Directory.CreateDirectory(_baseDirPath);
_baseDirPath = globalSettings.Send.BaseDirectory;
_baseSendUrl = globalSettings.Send.BaseUrl;
}
return Task.FromResult(0);
}
public Task<string> GetSendFileUploadUrlAsync(Send send, string fileId)
=> Task.FromResult($"/sends/{send.Id}/file/{fileId}");
public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway)
{
long? length = null;
var path = FilePath(send, fileId);
if (!File.Exists(path))
public async Task UploadNewFileAsync(Stream stream, Send send, string fileId)
{
return Task.FromResult((false, length));
await InitAsync();
var path = FilePath(send, fileId);
Directory.CreateDirectory(Path.GetDirectoryName(path));
using (var fs = File.Create(path))
{
stream.Seek(0, SeekOrigin.Begin);
await stream.CopyToAsync(fs);
}
}
length = new FileInfo(path).Length;
if (expectedFileSize < length - leeway || expectedFileSize > length + leeway)
public async Task DeleteFileAsync(Send send, string fileId)
{
return Task.FromResult((false, length));
await InitAsync();
var path = FilePath(send, fileId);
DeleteFileIfExists(path);
DeleteDirectoryIfExistsAndEmpty(Path.GetDirectoryName(path));
}
return Task.FromResult((true, length));
public async Task DeleteFilesForOrganizationAsync(Guid organizationId)
{
await InitAsync();
}
public async Task DeleteFilesForUserAsync(Guid userId)
{
await InitAsync();
}
public async Task<string> GetSendFileDownloadUrlAsync(Send send, string fileId)
{
await InitAsync();
return $"{_baseSendUrl}/{RelativeFilePath(send, fileId)}";
}
private void DeleteFileIfExists(string path)
{
if (File.Exists(path))
{
File.Delete(path);
}
}
private void DeleteDirectoryIfExistsAndEmpty(string path)
{
if (Directory.Exists(path) && !Directory.EnumerateFiles(path).Any())
{
Directory.Delete(path);
}
}
private Task InitAsync()
{
if (!Directory.Exists(_baseDirPath))
{
Directory.CreateDirectory(_baseDirPath);
}
return Task.FromResult(0);
}
public Task<string> GetSendFileUploadUrlAsync(Send send, string fileId)
=> Task.FromResult($"/sends/{send.Id}/file/{fileId}");
public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway)
{
long? length = null;
var path = FilePath(send, fileId);
if (!File.Exists(path))
{
return Task.FromResult((false, length));
}
length = new FileInfo(path).Length;
if (expectedFileSize < length - leeway || expectedFileSize > length + leeway)
{
return Task.FromResult((false, length));
}
return Task.FromResult((true, length));
}
}
}

View File

@ -4,97 +4,98 @@ using MailKit.Net.Smtp;
using Microsoft.Extensions.Logging;
using MimeKit;
namespace Bit.Core.Services;
public class MailKitSmtpMailDeliveryService : IMailDeliveryService
namespace Bit.Core.Services
{
private readonly GlobalSettings _globalSettings;
private readonly ILogger<MailKitSmtpMailDeliveryService> _logger;
private readonly string _replyDomain;
private readonly string _replyEmail;
public MailKitSmtpMailDeliveryService(
GlobalSettings globalSettings,
ILogger<MailKitSmtpMailDeliveryService> logger)
public class MailKitSmtpMailDeliveryService : IMailDeliveryService
{
if (globalSettings.Mail?.Smtp?.Host == null)
private readonly GlobalSettings _globalSettings;
private readonly ILogger<MailKitSmtpMailDeliveryService> _logger;
private readonly string _replyDomain;
private readonly string _replyEmail;
public MailKitSmtpMailDeliveryService(
GlobalSettings globalSettings,
ILogger<MailKitSmtpMailDeliveryService> logger)
{
throw new ArgumentNullException(nameof(globalSettings.Mail.Smtp.Host));
if (globalSettings.Mail?.Smtp?.Host == null)
{
throw new ArgumentNullException(nameof(globalSettings.Mail.Smtp.Host));
}
_replyEmail = CoreHelpers.PunyEncode(globalSettings.Mail?.ReplyToEmail);
if (_replyEmail.Contains("@"))
{
_replyDomain = _replyEmail.Split('@')[1];
}
_globalSettings = globalSettings;
_logger = logger;
}
_replyEmail = CoreHelpers.PunyEncode(globalSettings.Mail?.ReplyToEmail);
if (_replyEmail.Contains("@"))
public async Task SendEmailAsync(Models.Mail.MailMessage message)
{
_replyDomain = _replyEmail.Split('@')[1];
}
var mimeMessage = new MimeMessage();
mimeMessage.From.Add(new MailboxAddress(_globalSettings.SiteName, _replyEmail));
mimeMessage.Subject = message.Subject;
if (!string.IsNullOrWhiteSpace(_replyDomain))
{
mimeMessage.MessageId = $"<{Guid.NewGuid()}@{_replyDomain}>";
}
_globalSettings = globalSettings;
_logger = logger;
}
public async Task SendEmailAsync(Models.Mail.MailMessage message)
{
var mimeMessage = new MimeMessage();
mimeMessage.From.Add(new MailboxAddress(_globalSettings.SiteName, _replyEmail));
mimeMessage.Subject = message.Subject;
if (!string.IsNullOrWhiteSpace(_replyDomain))
{
mimeMessage.MessageId = $"<{Guid.NewGuid()}@{_replyDomain}>";
}
foreach (var address in message.ToEmails)
{
var punyencoded = CoreHelpers.PunyEncode(address);
mimeMessage.To.Add(MailboxAddress.Parse(punyencoded));
}
if (message.BccEmails != null)
{
foreach (var address in message.BccEmails)
foreach (var address in message.ToEmails)
{
var punyencoded = CoreHelpers.PunyEncode(address);
mimeMessage.Bcc.Add(MailboxAddress.Parse(punyencoded));
mimeMessage.To.Add(MailboxAddress.Parse(punyencoded));
}
}
var builder = new BodyBuilder();
if (!string.IsNullOrWhiteSpace(message.TextContent))
{
builder.TextBody = message.TextContent;
}
builder.HtmlBody = message.HtmlContent;
mimeMessage.Body = builder.ToMessageBody();
using (var client = new SmtpClient())
{
if (_globalSettings.Mail.Smtp.TrustServer)
if (message.BccEmails != null)
{
client.ServerCertificateValidationCallback = (s, c, h, e) => true;
foreach (var address in message.BccEmails)
{
var punyencoded = CoreHelpers.PunyEncode(address);
mimeMessage.Bcc.Add(MailboxAddress.Parse(punyencoded));
}
}
if (!_globalSettings.Mail.Smtp.StartTls && !_globalSettings.Mail.Smtp.Ssl &&
_globalSettings.Mail.Smtp.Port == 25)
var builder = new BodyBuilder();
if (!string.IsNullOrWhiteSpace(message.TextContent))
{
await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port,
MailKit.Security.SecureSocketOptions.None);
}
else
{
var useSsl = _globalSettings.Mail.Smtp.Port == 587 && !_globalSettings.Mail.Smtp.SslOverride ?
false : _globalSettings.Mail.Smtp.Ssl;
await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, useSsl);
builder.TextBody = message.TextContent;
}
builder.HtmlBody = message.HtmlContent;
mimeMessage.Body = builder.ToMessageBody();
if (CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Username) &&
CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Password))
using (var client = new SmtpClient())
{
await client.AuthenticateAsync(_globalSettings.Mail.Smtp.Username,
_globalSettings.Mail.Smtp.Password);
}
if (_globalSettings.Mail.Smtp.TrustServer)
{
client.ServerCertificateValidationCallback = (s, c, h, e) => true;
}
await client.SendAsync(mimeMessage);
await client.DisconnectAsync(true);
if (!_globalSettings.Mail.Smtp.StartTls && !_globalSettings.Mail.Smtp.Ssl &&
_globalSettings.Mail.Smtp.Port == 25)
{
await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port,
MailKit.Security.SecureSocketOptions.None);
}
else
{
var useSsl = _globalSettings.Mail.Smtp.Port == 587 && !_globalSettings.Mail.Smtp.SslOverride ?
false : _globalSettings.Mail.Smtp.Ssl;
await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, useSsl);
}
if (CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Username) &&
CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Password))
{
await client.AuthenticateAsync(_globalSettings.Mail.Smtp.Username,
_globalSettings.Mail.Smtp.Password);
}
await client.SendAsync(mimeMessage);
await client.DisconnectAsync(true);
}
}
}
}

View File

@ -3,39 +3,40 @@ using Bit.Core.Settings;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class MultiServiceMailDeliveryService : IMailDeliveryService
namespace Bit.Core.Services
{
private readonly IMailDeliveryService _sesService;
private readonly IMailDeliveryService _sendGridService;
private readonly int _sendGridPercentage;
private static Random _random = new Random();
public MultiServiceMailDeliveryService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<AmazonSesMailDeliveryService> sesLogger,
ILogger<SendGridMailDeliveryService> sendGridLogger)
public class MultiServiceMailDeliveryService : IMailDeliveryService
{
_sesService = new AmazonSesMailDeliveryService(globalSettings, hostingEnvironment, sesLogger);
_sendGridService = new SendGridMailDeliveryService(globalSettings, hostingEnvironment, sendGridLogger);
private readonly IMailDeliveryService _sesService;
private readonly IMailDeliveryService _sendGridService;
private readonly int _sendGridPercentage;
// disabled by default (-1)
_sendGridPercentage = (globalSettings.Mail?.SendGridPercentage).GetValueOrDefault(-1);
}
private static Random _random = new Random();
public async Task SendEmailAsync(MailMessage message)
{
var roll = _random.Next(0, 99);
if (roll < _sendGridPercentage)
public MultiServiceMailDeliveryService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<AmazonSesMailDeliveryService> sesLogger,
ILogger<SendGridMailDeliveryService> sendGridLogger)
{
await _sendGridService.SendEmailAsync(message);
_sesService = new AmazonSesMailDeliveryService(globalSettings, hostingEnvironment, sesLogger);
_sendGridService = new SendGridMailDeliveryService(globalSettings, hostingEnvironment, sendGridLogger);
// disabled by default (-1)
_sendGridPercentage = (globalSettings.Mail?.SendGridPercentage).GetValueOrDefault(-1);
}
else
public async Task SendEmailAsync(MailMessage message)
{
await _sesService.SendEmailAsync(message);
var roll = _random.Next(0, 99);
if (roll < _sendGridPercentage)
{
await _sendGridService.SendEmailAsync(message);
}
else
{
await _sesService.SendEmailAsync(message);
}
}
}
}

View File

@ -6,160 +6,161 @@ using Bit.Core.Utilities;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class MultiServicePushNotificationService : IPushNotificationService
namespace Bit.Core.Services
{
private readonly List<IPushNotificationService> _services = new List<IPushNotificationService>();
private readonly ILogger<MultiServicePushNotificationService> _logger;
public MultiServicePushNotificationService(
IHttpClientFactory httpFactory,
IDeviceRepository deviceRepository,
IInstallationDeviceRepository installationDeviceRepository,
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor,
ILogger<MultiServicePushNotificationService> logger,
ILogger<RelayPushNotificationService> relayLogger,
ILogger<NotificationsApiPushNotificationService> hubLogger)
public class MultiServicePushNotificationService : IPushNotificationService
{
if (globalSettings.SelfHosted)
private readonly List<IPushNotificationService> _services = new List<IPushNotificationService>();
private readonly ILogger<MultiServicePushNotificationService> _logger;
public MultiServicePushNotificationService(
IHttpClientFactory httpFactory,
IDeviceRepository deviceRepository,
IInstallationDeviceRepository installationDeviceRepository,
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor,
ILogger<MultiServicePushNotificationService> logger,
ILogger<RelayPushNotificationService> relayLogger,
ILogger<NotificationsApiPushNotificationService> hubLogger)
{
if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) &&
globalSettings.Installation?.Id != null &&
CoreHelpers.SettingHasValue(globalSettings.Installation?.Key))
if (globalSettings.SelfHosted)
{
_services.Add(new RelayPushNotificationService(httpFactory, deviceRepository, globalSettings,
httpContextAccessor, relayLogger));
if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) &&
globalSettings.Installation?.Id != null &&
CoreHelpers.SettingHasValue(globalSettings.Installation?.Key))
{
_services.Add(new RelayPushNotificationService(httpFactory, deviceRepository, globalSettings,
httpContextAccessor, relayLogger));
}
if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) &&
CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications))
{
_services.Add(new NotificationsApiPushNotificationService(
httpFactory, globalSettings, httpContextAccessor, hubLogger));
}
}
if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) &&
CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications))
else
{
_services.Add(new NotificationsApiPushNotificationService(
httpFactory, globalSettings, httpContextAccessor, hubLogger));
}
}
else
{
if (CoreHelpers.SettingHasValue(globalSettings.NotificationHub.ConnectionString))
{
_services.Add(new NotificationHubPushNotificationService(installationDeviceRepository,
globalSettings, httpContextAccessor));
}
if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString))
{
_services.Add(new AzureQueuePushNotificationService(globalSettings, httpContextAccessor));
if (CoreHelpers.SettingHasValue(globalSettings.NotificationHub.ConnectionString))
{
_services.Add(new NotificationHubPushNotificationService(installationDeviceRepository,
globalSettings, httpContextAccessor));
}
if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString))
{
_services.Add(new AzureQueuePushNotificationService(globalSettings, httpContextAccessor));
}
}
_logger = logger;
}
_logger = logger;
}
public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
PushToServices((s) => s.PushSyncCipherCreateAsync(cipher, collectionIds));
return Task.FromResult(0);
}
public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
PushToServices((s) => s.PushSyncCipherUpdateAsync(cipher, collectionIds));
return Task.FromResult(0);
}
public Task PushSyncCipherDeleteAsync(Cipher cipher)
{
PushToServices((s) => s.PushSyncCipherDeleteAsync(cipher));
return Task.FromResult(0);
}
public Task PushSyncFolderCreateAsync(Folder folder)
{
PushToServices((s) => s.PushSyncFolderCreateAsync(folder));
return Task.FromResult(0);
}
public Task PushSyncFolderUpdateAsync(Folder folder)
{
PushToServices((s) => s.PushSyncFolderUpdateAsync(folder));
return Task.FromResult(0);
}
public Task PushSyncFolderDeleteAsync(Folder folder)
{
PushToServices((s) => s.PushSyncFolderDeleteAsync(folder));
return Task.FromResult(0);
}
public Task PushSyncCiphersAsync(Guid userId)
{
PushToServices((s) => s.PushSyncCiphersAsync(userId));
return Task.FromResult(0);
}
public Task PushSyncVaultAsync(Guid userId)
{
PushToServices((s) => s.PushSyncVaultAsync(userId));
return Task.FromResult(0);
}
public Task PushSyncOrgKeysAsync(Guid userId)
{
PushToServices((s) => s.PushSyncOrgKeysAsync(userId));
return Task.FromResult(0);
}
public Task PushSyncSettingsAsync(Guid userId)
{
PushToServices((s) => s.PushSyncSettingsAsync(userId));
return Task.FromResult(0);
}
public Task PushLogOutAsync(Guid userId)
{
PushToServices((s) => s.PushLogOutAsync(userId));
return Task.FromResult(0);
}
public Task PushSyncSendCreateAsync(Send send)
{
PushToServices((s) => s.PushSyncSendCreateAsync(send));
return Task.FromResult(0);
}
public Task PushSyncSendUpdateAsync(Send send)
{
PushToServices((s) => s.PushSyncSendUpdateAsync(send));
return Task.FromResult(0);
}
public Task PushSyncSendDeleteAsync(Send send)
{
PushToServices((s) => s.PushSyncSendDeleteAsync(send));
return Task.FromResult(0);
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId));
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId));
return Task.FromResult(0);
}
private void PushToServices(Func<IPushNotificationService, Task> pushFunc)
{
if (_services != null)
public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
foreach (var service in _services)
PushToServices((s) => s.PushSyncCipherCreateAsync(cipher, collectionIds));
return Task.FromResult(0);
}
public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
PushToServices((s) => s.PushSyncCipherUpdateAsync(cipher, collectionIds));
return Task.FromResult(0);
}
public Task PushSyncCipherDeleteAsync(Cipher cipher)
{
PushToServices((s) => s.PushSyncCipherDeleteAsync(cipher));
return Task.FromResult(0);
}
public Task PushSyncFolderCreateAsync(Folder folder)
{
PushToServices((s) => s.PushSyncFolderCreateAsync(folder));
return Task.FromResult(0);
}
public Task PushSyncFolderUpdateAsync(Folder folder)
{
PushToServices((s) => s.PushSyncFolderUpdateAsync(folder));
return Task.FromResult(0);
}
public Task PushSyncFolderDeleteAsync(Folder folder)
{
PushToServices((s) => s.PushSyncFolderDeleteAsync(folder));
return Task.FromResult(0);
}
public Task PushSyncCiphersAsync(Guid userId)
{
PushToServices((s) => s.PushSyncCiphersAsync(userId));
return Task.FromResult(0);
}
public Task PushSyncVaultAsync(Guid userId)
{
PushToServices((s) => s.PushSyncVaultAsync(userId));
return Task.FromResult(0);
}
public Task PushSyncOrgKeysAsync(Guid userId)
{
PushToServices((s) => s.PushSyncOrgKeysAsync(userId));
return Task.FromResult(0);
}
public Task PushSyncSettingsAsync(Guid userId)
{
PushToServices((s) => s.PushSyncSettingsAsync(userId));
return Task.FromResult(0);
}
public Task PushLogOutAsync(Guid userId)
{
PushToServices((s) => s.PushLogOutAsync(userId));
return Task.FromResult(0);
}
public Task PushSyncSendCreateAsync(Send send)
{
PushToServices((s) => s.PushSyncSendCreateAsync(send));
return Task.FromResult(0);
}
public Task PushSyncSendUpdateAsync(Send send)
{
PushToServices((s) => s.PushSyncSendUpdateAsync(send));
return Task.FromResult(0);
}
public Task PushSyncSendDeleteAsync(Send send)
{
PushToServices((s) => s.PushSyncSendDeleteAsync(send));
return Task.FromResult(0);
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId));
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId));
return Task.FromResult(0);
}
private void PushToServices(Func<IPushNotificationService, Task> pushFunc)
{
if (_services != null)
{
pushFunc(service);
foreach (var service in _services)
{
pushFunc(service);
}
}
}
}

View File

@ -10,230 +10,231 @@ using Bit.Core.Settings;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.NotificationHubs;
namespace Bit.Core.Services;
public class NotificationHubPushNotificationService : IPushNotificationService
namespace Bit.Core.Services
{
private readonly IInstallationDeviceRepository _installationDeviceRepository;
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
private NotificationHubClient _client = null;
public NotificationHubPushNotificationService(
IInstallationDeviceRepository installationDeviceRepository,
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor)
public class NotificationHubPushNotificationService : IPushNotificationService
{
_installationDeviceRepository = installationDeviceRepository;
_globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor;
_client = NotificationHubClient.CreateClientFromConnectionString(
_globalSettings.NotificationHub.ConnectionString,
_globalSettings.NotificationHub.HubName);
}
private readonly IInstallationDeviceRepository _installationDeviceRepository;
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds);
}
private NotificationHubClient _client = null;
public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds);
}
public async Task PushSyncCipherDeleteAsync(Cipher cipher)
{
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
{
if (cipher.OrganizationId.HasValue)
public NotificationHubPushNotificationService(
IInstallationDeviceRepository installationDeviceRepository,
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor)
{
// We cannot send org pushes since access logic is much more complicated than just the fact that they belong
// to the organization. Potentially we could blindly send to just users that have the access all permission
// device registration needs to be more granular to handle that appropriately. A more brute force approach could
// me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts.
// await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true);
_installationDeviceRepository = installationDeviceRepository;
_globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor;
_client = NotificationHubClient.CreateClientFromConnectionString(
_globalSettings.NotificationHub.ConnectionString,
_globalSettings.NotificationHub.HubName);
}
else if (cipher.UserId.HasValue)
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
var message = new SyncCipherPushNotification
await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds);
}
public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds);
}
public async Task PushSyncCipherDeleteAsync(Cipher cipher)
{
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
{
if (cipher.OrganizationId.HasValue)
{
Id = cipher.Id,
UserId = cipher.UserId,
OrganizationId = cipher.OrganizationId,
RevisionDate = cipher.RevisionDate,
CollectionIds = collectionIds,
// We cannot send org pushes since access logic is much more complicated than just the fact that they belong
// to the organization. Potentially we could blindly send to just users that have the access all permission
// device registration needs to be more granular to handle that appropriately. A more brute force approach could
// me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts.
// await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true);
}
else if (cipher.UserId.HasValue)
{
var message = new SyncCipherPushNotification
{
Id = cipher.Id,
UserId = cipher.UserId,
OrganizationId = cipher.OrganizationId,
RevisionDate = cipher.RevisionDate,
CollectionIds = collectionIds,
};
await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true);
}
}
public async Task PushSyncFolderCreateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderCreate);
}
public async Task PushSyncFolderUpdateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderUpdate);
}
public async Task PushSyncFolderDeleteAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderDelete);
}
private async Task PushFolderAsync(Folder folder, PushType type)
{
var message = new SyncFolderPushNotification
{
Id = folder.Id,
UserId = folder.UserId,
RevisionDate = folder.RevisionDate
};
await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true);
await SendPayloadToUserAsync(folder.UserId, type, message, true);
}
}
public async Task PushSyncFolderCreateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderCreate);
}
public async Task PushSyncFolderUpdateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderUpdate);
}
public async Task PushSyncFolderDeleteAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderDelete);
}
private async Task PushFolderAsync(Folder folder, PushType type)
{
var message = new SyncFolderPushNotification
public async Task PushSyncCiphersAsync(Guid userId)
{
Id = folder.Id,
UserId = folder.UserId,
RevisionDate = folder.RevisionDate
};
await PushUserAsync(userId, PushType.SyncCiphers);
}
await SendPayloadToUserAsync(folder.UserId, type, message, true);
}
public async Task PushSyncCiphersAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncCiphers);
}
public async Task PushSyncVaultAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncVault);
}
public async Task PushSyncOrgKeysAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncOrgKeys);
}
public async Task PushSyncSettingsAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncSettings);
}
public async Task PushLogOutAsync(Guid userId)
{
await PushUserAsync(userId, PushType.LogOut);
}
private async Task PushUserAsync(Guid userId, PushType type)
{
var message = new UserPushNotification
public async Task PushSyncVaultAsync(Guid userId)
{
UserId = userId,
Date = DateTime.UtcNow
};
await PushUserAsync(userId, PushType.SyncVault);
}
await SendPayloadToUserAsync(userId, type, message, false);
}
public async Task PushSyncSendCreateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendCreate);
}
public async Task PushSyncSendUpdateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendUpdate);
}
public async Task PushSyncSendDeleteAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendDelete);
}
private async Task PushSendAsync(Send send, PushType type)
{
if (send.UserId.HasValue)
public async Task PushSyncOrgKeysAsync(Guid userId)
{
var message = new SyncSendPushNotification
await PushUserAsync(userId, PushType.SyncOrgKeys);
}
public async Task PushSyncSettingsAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncSettings);
}
public async Task PushLogOutAsync(Guid userId)
{
await PushUserAsync(userId, PushType.LogOut);
}
private async Task PushUserAsync(Guid userId, PushType type)
{
var message = new UserPushNotification
{
Id = send.Id,
UserId = send.UserId.Value,
RevisionDate = send.RevisionDate
UserId = userId,
Date = DateTime.UtcNow
};
await SendPayloadToUserAsync(message.UserId, type, message, true);
await SendPayloadToUserAsync(userId, type, message, false);
}
}
private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext)
{
await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext));
}
private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext)
{
await SendPayloadToUserAsync(orgId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext));
}
public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier);
await SendPayloadAsync(tag, type, payload);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
public async Task PushSyncSendCreateAsync(Send send)
{
await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId));
await PushSendAsync(send, PushType.SyncSendCreate);
}
}
public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier);
await SendPayloadAsync(tag, type, payload);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
public async Task PushSyncSendUpdateAsync(Send send)
{
await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId));
await PushSendAsync(send, PushType.SyncSendUpdate);
}
}
private string GetContextIdentifier(bool excludeCurrentContext)
{
if (!excludeCurrentContext)
public async Task PushSyncSendDeleteAsync(Send send)
{
return null;
await PushSendAsync(send, PushType.SyncSendDelete);
}
var currentContext = _httpContextAccessor?.HttpContext?.
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier;
}
private string BuildTag(string tag, string identifier)
{
if (!string.IsNullOrWhiteSpace(identifier))
private async Task PushSendAsync(Send send, PushType type)
{
tag += $" && !deviceIdentifier:{SanitizeTagInput(identifier)}";
}
return $"({tag})";
}
private async Task SendPayloadAsync(string tag, PushType type, object payload)
{
await _client.SendTemplateNotificationAsync(
new Dictionary<string, string>
if (send.UserId.HasValue)
{
{ "type", ((byte)type).ToString() },
{ "payload", JsonSerializer.Serialize(payload) }
}, tag);
}
var message = new SyncSendPushNotification
{
Id = send.Id,
UserId = send.UserId.Value,
RevisionDate = send.RevisionDate
};
private string SanitizeTagInput(string input)
{
// Only allow a-z, A-Z, 0-9, and special characters -_:
return Regex.Replace(input, "[^a-zA-Z0-9-_:]", string.Empty);
await SendPayloadToUserAsync(message.UserId, type, message, true);
}
}
private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext)
{
await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext));
}
private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext)
{
await SendPayloadToUserAsync(orgId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext));
}
public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier);
await SendPayloadAsync(tag, type, payload);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{
await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId));
}
}
public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier);
await SendPayloadAsync(tag, type, payload);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{
await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId));
}
}
private string GetContextIdentifier(bool excludeCurrentContext)
{
if (!excludeCurrentContext)
{
return null;
}
var currentContext = _httpContextAccessor?.HttpContext?.
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier;
}
private string BuildTag(string tag, string identifier)
{
if (!string.IsNullOrWhiteSpace(identifier))
{
tag += $" && !deviceIdentifier:{SanitizeTagInput(identifier)}";
}
return $"({tag})";
}
private async Task SendPayloadAsync(string tag, PushType type, object payload)
{
await _client.SendTemplateNotificationAsync(
new Dictionary<string, string>
{
{ "type", ((byte)type).ToString() },
{ "payload", JsonSerializer.Serialize(payload) }
}, tag);
}
private string SanitizeTagInput(string input)
{
// Only allow a-z, A-Z, 0-9, and special characters -_:
return Regex.Replace(input, "[^a-zA-Z0-9-_:]", string.Empty);
}
}
}

View File

@ -4,191 +4,192 @@ using Bit.Core.Repositories;
using Bit.Core.Settings;
using Microsoft.Azure.NotificationHubs;
namespace Bit.Core.Services;
public class NotificationHubPushRegistrationService : IPushRegistrationService
namespace Bit.Core.Services
{
private readonly IInstallationDeviceRepository _installationDeviceRepository;
private readonly GlobalSettings _globalSettings;
private NotificationHubClient _client = null;
public NotificationHubPushRegistrationService(
IInstallationDeviceRepository installationDeviceRepository,
GlobalSettings globalSettings)
public class NotificationHubPushRegistrationService : IPushRegistrationService
{
_installationDeviceRepository = installationDeviceRepository;
_globalSettings = globalSettings;
_client = NotificationHubClient.CreateClientFromConnectionString(
_globalSettings.NotificationHub.ConnectionString,
_globalSettings.NotificationHub.HubName);
}
private readonly IInstallationDeviceRepository _installationDeviceRepository;
private readonly GlobalSettings _globalSettings;
public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId,
string identifier, DeviceType type)
{
if (string.IsNullOrWhiteSpace(pushToken))
private NotificationHubClient _client = null;
public NotificationHubPushRegistrationService(
IInstallationDeviceRepository installationDeviceRepository,
GlobalSettings globalSettings)
{
return;
_installationDeviceRepository = installationDeviceRepository;
_globalSettings = globalSettings;
_client = NotificationHubClient.CreateClientFromConnectionString(
_globalSettings.NotificationHub.ConnectionString,
_globalSettings.NotificationHub.HubName);
}
var installation = new Installation
public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId,
string identifier, DeviceType type)
{
InstallationId = deviceId,
PushChannel = pushToken,
Templates = new Dictionary<string, InstallationTemplate>()
};
installation.Tags = new List<string>
{
$"userId:{userId}"
};
if (!string.IsNullOrWhiteSpace(identifier))
{
installation.Tags.Add("deviceIdentifier:" + identifier);
}
string payloadTemplate = null, messageTemplate = null, badgeMessageTemplate = null;
switch (type)
{
case DeviceType.Android:
payloadTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}}";
messageTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\"}," +
"\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}";
installation.Platform = NotificationPlatform.Fcm;
break;
case DeviceType.iOS:
payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}," +
"\"aps\":{\"content-available\":1}}";
messageTemplate = "{\"data\":{\"type\":\"#(type)\"}," +
"\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}";
badgeMessageTemplate = "{\"data\":{\"type\":\"#(type)\"}," +
"\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}";
installation.Platform = NotificationPlatform.Apns;
break;
case DeviceType.AndroidAmazon:
payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}";
messageTemplate = "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}";
installation.Platform = NotificationPlatform.Adm;
break;
default:
break;
}
BuildInstallationTemplate(installation, "payload", payloadTemplate, userId, identifier);
BuildInstallationTemplate(installation, "message", messageTemplate, userId, identifier);
BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate,
userId, identifier);
await _client.CreateOrUpdateInstallationAsync(installation);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{
await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId));
}
}
private void BuildInstallationTemplate(Installation installation, string templateId, string templateBody,
string userId, string identifier)
{
if (templateBody == null)
{
return;
}
var fullTemplateId = $"template:{templateId}";
var template = new InstallationTemplate
{
Body = templateBody,
Tags = new List<string>
if (string.IsNullOrWhiteSpace(pushToken))
{
fullTemplateId,
$"{fullTemplateId}_userId:{userId}"
return;
}
};
if (!string.IsNullOrWhiteSpace(identifier))
{
template.Tags.Add($"{fullTemplateId}_deviceIdentifier:{identifier}");
}
var installation = new Installation
{
InstallationId = deviceId,
PushChannel = pushToken,
Templates = new Dictionary<string, InstallationTemplate>()
};
installation.Templates.Add(fullTemplateId, template);
}
installation.Tags = new List<string>
{
$"userId:{userId}"
};
public async Task DeleteRegistrationAsync(string deviceId)
{
try
{
await _client.DeleteInstallationAsync(deviceId);
if (!string.IsNullOrWhiteSpace(identifier))
{
installation.Tags.Add("deviceIdentifier:" + identifier);
}
string payloadTemplate = null, messageTemplate = null, badgeMessageTemplate = null;
switch (type)
{
case DeviceType.Android:
payloadTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}}";
messageTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\"}," +
"\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}";
installation.Platform = NotificationPlatform.Fcm;
break;
case DeviceType.iOS:
payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}," +
"\"aps\":{\"content-available\":1}}";
messageTemplate = "{\"data\":{\"type\":\"#(type)\"}," +
"\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}";
badgeMessageTemplate = "{\"data\":{\"type\":\"#(type)\"}," +
"\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}";
installation.Platform = NotificationPlatform.Apns;
break;
case DeviceType.AndroidAmazon:
payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}";
messageTemplate = "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}";
installation.Platform = NotificationPlatform.Adm;
break;
default:
break;
}
BuildInstallationTemplate(installation, "payload", payloadTemplate, userId, identifier);
BuildInstallationTemplate(installation, "message", messageTemplate, userId, identifier);
BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate,
userId, identifier);
await _client.CreateOrUpdateInstallationAsync(installation);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{
await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId));
await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId));
}
}
catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found"))
{
throw;
}
}
public async Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
{
await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}");
if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First()))
private void BuildInstallationTemplate(Installation installation, string templateId, string templateBody,
string userId, string identifier)
{
var entities = deviceIds.Select(e => new InstallationDeviceEntity(e));
await _installationDeviceRepository.UpsertManyAsync(entities.ToList());
}
}
if (templateBody == null)
{
return;
}
public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
{
await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove,
$"organizationId:{organizationId}");
if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First()))
{
var entities = deviceIds.Select(e => new InstallationDeviceEntity(e));
await _installationDeviceRepository.UpsertManyAsync(entities.ToList());
}
}
var fullTemplateId = $"template:{templateId}";
private async Task PatchTagsForUserDevicesAsync(IEnumerable<string> deviceIds, UpdateOperationType op,
string tag)
{
if (!deviceIds.Any())
{
return;
var template = new InstallationTemplate
{
Body = templateBody,
Tags = new List<string>
{
fullTemplateId,
$"{fullTemplateId}_userId:{userId}"
}
};
if (!string.IsNullOrWhiteSpace(identifier))
{
template.Tags.Add($"{fullTemplateId}_deviceIdentifier:{identifier}");
}
installation.Templates.Add(fullTemplateId, template);
}
var operation = new PartialUpdateOperation
{
Operation = op,
Path = "/tags"
};
if (op == UpdateOperationType.Add)
{
operation.Value = tag;
}
else if (op == UpdateOperationType.Remove)
{
operation.Path += $"/{tag}";
}
foreach (var id in deviceIds)
public async Task DeleteRegistrationAsync(string deviceId)
{
try
{
await _client.PatchInstallationAsync(id, new List<PartialUpdateOperation> { operation });
await _client.DeleteInstallationAsync(deviceId);
if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId))
{
await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId));
}
}
catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found"))
{
throw;
}
}
public async Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
{
await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}");
if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First()))
{
var entities = deviceIds.Select(e => new InstallationDeviceEntity(e));
await _installationDeviceRepository.UpsertManyAsync(entities.ToList());
}
}
public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
{
await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove,
$"organizationId:{organizationId}");
if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First()))
{
var entities = deviceIds.Select(e => new InstallationDeviceEntity(e));
await _installationDeviceRepository.UpsertManyAsync(entities.ToList());
}
}
private async Task PatchTagsForUserDevicesAsync(IEnumerable<string> deviceIds, UpdateOperationType op,
string tag)
{
if (!deviceIds.Any())
{
return;
}
var operation = new PartialUpdateOperation
{
Operation = op,
Path = "/tags"
};
if (op == UpdateOperationType.Add)
{
operation.Value = tag;
}
else if (op == UpdateOperationType.Remove)
{
operation.Path += $"/{tag}";
}
foreach (var id in deviceIds)
{
try
{
await _client.PatchInstallationAsync(id, new List<PartialUpdateOperation> { operation });
}
catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found"))
{
throw;
}
}
}
}
}

View File

@ -6,197 +6,198 @@ using Bit.Core.Settings;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class NotificationsApiPushNotificationService : BaseIdentityClientService, IPushNotificationService
namespace Bit.Core.Services
{
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
public NotificationsApiPushNotificationService(
IHttpClientFactory httpFactory,
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor,
ILogger<NotificationsApiPushNotificationService> logger)
: base(
httpFactory,
globalSettings.BaseServiceUri.InternalNotifications,
globalSettings.BaseServiceUri.InternalIdentity,
"internal",
$"internal.{globalSettings.ProjectName}",
globalSettings.InternalIdentityKey,
logger)
public class NotificationsApiPushNotificationService : BaseIdentityClientService, IPushNotificationService
{
_globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor;
}
private readonly GlobalSettings _globalSettings;
private readonly IHttpContextAccessor _httpContextAccessor;
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds);
}
public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds);
}
public async Task PushSyncCipherDeleteAsync(Cipher cipher)
{
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
{
if (cipher.OrganizationId.HasValue)
public NotificationsApiPushNotificationService(
IHttpClientFactory httpFactory,
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor,
ILogger<NotificationsApiPushNotificationService> logger)
: base(
httpFactory,
globalSettings.BaseServiceUri.InternalNotifications,
globalSettings.BaseServiceUri.InternalIdentity,
"internal",
$"internal.{globalSettings.ProjectName}",
globalSettings.InternalIdentityKey,
logger)
{
var message = new SyncCipherPushNotification
_globalSettings = globalSettings;
_httpContextAccessor = httpContextAccessor;
}
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds);
}
public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds);
}
public async Task PushSyncCipherDeleteAsync(Cipher cipher)
{
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
{
if (cipher.OrganizationId.HasValue)
{
Id = cipher.Id,
OrganizationId = cipher.OrganizationId,
RevisionDate = cipher.RevisionDate,
CollectionIds = collectionIds,
var message = new SyncCipherPushNotification
{
Id = cipher.Id,
OrganizationId = cipher.OrganizationId,
RevisionDate = cipher.RevisionDate,
CollectionIds = collectionIds,
};
await SendMessageAsync(type, message, true);
}
else if (cipher.UserId.HasValue)
{
var message = new SyncCipherPushNotification
{
Id = cipher.Id,
UserId = cipher.UserId,
RevisionDate = cipher.RevisionDate,
CollectionIds = collectionIds,
};
await SendMessageAsync(type, message, true);
}
}
public async Task PushSyncFolderCreateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderCreate);
}
public async Task PushSyncFolderUpdateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderUpdate);
}
public async Task PushSyncFolderDeleteAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderDelete);
}
private async Task PushFolderAsync(Folder folder, PushType type)
{
var message = new SyncFolderPushNotification
{
Id = folder.Id,
UserId = folder.UserId,
RevisionDate = folder.RevisionDate
};
await SendMessageAsync(type, message, true);
}
else if (cipher.UserId.HasValue)
{
var message = new SyncCipherPushNotification
{
Id = cipher.Id,
UserId = cipher.UserId,
RevisionDate = cipher.RevisionDate,
CollectionIds = collectionIds,
};
await SendMessageAsync(type, message, true);
public async Task PushSyncCiphersAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncCiphers);
}
}
public async Task PushSyncFolderCreateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderCreate);
}
public async Task PushSyncFolderUpdateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderUpdate);
}
public async Task PushSyncFolderDeleteAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderDelete);
}
private async Task PushFolderAsync(Folder folder, PushType type)
{
var message = new SyncFolderPushNotification
public async Task PushSyncVaultAsync(Guid userId)
{
Id = folder.Id,
UserId = folder.UserId,
RevisionDate = folder.RevisionDate
};
await PushUserAsync(userId, PushType.SyncVault);
}
await SendMessageAsync(type, message, true);
}
public async Task PushSyncCiphersAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncCiphers);
}
public async Task PushSyncVaultAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncVault);
}
public async Task PushSyncOrgKeysAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncOrgKeys);
}
public async Task PushSyncSettingsAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncSettings);
}
public async Task PushLogOutAsync(Guid userId)
{
await PushUserAsync(userId, PushType.LogOut);
}
private async Task PushUserAsync(Guid userId, PushType type)
{
var message = new UserPushNotification
public async Task PushSyncOrgKeysAsync(Guid userId)
{
UserId = userId,
Date = DateTime.UtcNow
};
await PushUserAsync(userId, PushType.SyncOrgKeys);
}
await SendMessageAsync(type, message, false);
}
public async Task PushSyncSendCreateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendCreate);
}
public async Task PushSyncSendUpdateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendUpdate);
}
public async Task PushSyncSendDeleteAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendDelete);
}
private async Task PushSendAsync(Send send, PushType type)
{
if (send.UserId.HasValue)
public async Task PushSyncSettingsAsync(Guid userId)
{
var message = new SyncSendPushNotification
await PushUserAsync(userId, PushType.SyncSettings);
}
public async Task PushLogOutAsync(Guid userId)
{
await PushUserAsync(userId, PushType.LogOut);
}
private async Task PushUserAsync(Guid userId, PushType type)
{
var message = new UserPushNotification
{
Id = send.Id,
UserId = send.UserId.Value,
RevisionDate = send.RevisionDate
UserId = userId,
Date = DateTime.UtcNow
};
await SendMessageAsync(type, message, false);
}
}
private async Task SendMessageAsync<T>(PushType type, T payload, bool excludeCurrentContext)
{
var contextId = GetContextIdentifier(excludeCurrentContext);
var request = new PushNotificationData<T>(type, payload, contextId);
await SendAsync(HttpMethod.Post, "send", request);
}
private string GetContextIdentifier(bool excludeCurrentContext)
{
if (!excludeCurrentContext)
public async Task PushSyncSendCreateAsync(Send send)
{
return null;
await PushSendAsync(send, PushType.SyncSendCreate);
}
var currentContext = _httpContextAccessor?.HttpContext?.
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier;
}
public async Task PushSyncSendUpdateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendUpdate);
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
// Noop
return Task.FromResult(0);
}
public async Task PushSyncSendDeleteAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendDelete);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
// Noop
return Task.FromResult(0);
private async Task PushSendAsync(Send send, PushType type)
{
if (send.UserId.HasValue)
{
var message = new SyncSendPushNotification
{
Id = send.Id,
UserId = send.UserId.Value,
RevisionDate = send.RevisionDate
};
await SendMessageAsync(type, message, false);
}
}
private async Task SendMessageAsync<T>(PushType type, T payload, bool excludeCurrentContext)
{
var contextId = GetContextIdentifier(excludeCurrentContext);
var request = new PushNotificationData<T>(type, payload, contextId);
await SendAsync(HttpMethod.Post, "send", request);
}
private string GetContextIdentifier(bool excludeCurrentContext)
{
if (!excludeCurrentContext)
{
return null;
}
var currentContext = _httpContextAccessor?.HttpContext?.
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
return currentContext?.DeviceIdentifier;
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
// Noop
return Task.FromResult(0);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
// Noop
return Task.FromResult(0);
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -3,169 +3,170 @@ using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
namespace Bit.Core.Services;
public class PolicyService : IPolicyService
namespace Bit.Core.Services
{
private readonly IEventService _eventService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IPolicyRepository _policyRepository;
private readonly ISsoConfigRepository _ssoConfigRepository;
private readonly IMailService _mailService;
public PolicyService(
IEventService eventService,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IPolicyRepository policyRepository,
ISsoConfigRepository ssoConfigRepository,
IMailService mailService)
public class PolicyService : IPolicyService
{
_eventService = eventService;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_policyRepository = policyRepository;
_ssoConfigRepository = ssoConfigRepository;
_mailService = mailService;
}
private readonly IEventService _eventService;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IPolicyRepository _policyRepository;
private readonly ISsoConfigRepository _ssoConfigRepository;
private readonly IMailService _mailService;
public async Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService,
Guid? savingUserId)
{
var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId);
if (org == null)
public PolicyService(
IEventService eventService,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IPolicyRepository policyRepository,
ISsoConfigRepository ssoConfigRepository,
IMailService mailService)
{
throw new BadRequestException("Organization not found");
_eventService = eventService;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_policyRepository = policyRepository;
_ssoConfigRepository = ssoConfigRepository;
_mailService = mailService;
}
if (!org.UsePolicies)
public async Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService,
Guid? savingUserId)
{
throw new BadRequestException("This organization cannot use policies.");
}
// Handle dependent policy checks
switch (policy.Type)
{
case PolicyType.SingleOrg:
if (!policy.Enabled)
{
await RequiredBySsoAsync(org);
await RequiredByVaultTimeoutAsync(org);
await RequiredByKeyConnectorAsync(org);
}
break;
case PolicyType.RequireSso:
if (policy.Enabled)
{
await DependsOnSingleOrgAsync(org);
}
else
{
await RequiredByKeyConnectorAsync(org);
}
break;
case PolicyType.MaximumVaultTimeout:
if (policy.Enabled)
{
await DependsOnSingleOrgAsync(org);
}
break;
}
var now = DateTime.UtcNow;
if (policy.Id == default(Guid))
{
policy.CreationDate = now;
}
if (policy.Enabled)
{
var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id);
if (!currentPolicy?.Enabled ?? true)
var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId);
if (org == null)
{
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(
policy.OrganizationId);
var removableOrgUsers = orgUsers.Where(ou =>
ou.Status != Enums.OrganizationUserStatusType.Invited &&
ou.Type != Enums.OrganizationUserType.Owner && ou.Type != Enums.OrganizationUserType.Admin &&
ou.UserId != savingUserId);
switch (policy.Type)
throw new BadRequestException("Organization not found");
}
if (!org.UsePolicies)
{
throw new BadRequestException("This organization cannot use policies.");
}
// Handle dependent policy checks
switch (policy.Type)
{
case PolicyType.SingleOrg:
if (!policy.Enabled)
{
await RequiredBySsoAsync(org);
await RequiredByVaultTimeoutAsync(org);
await RequiredByKeyConnectorAsync(org);
}
break;
case PolicyType.RequireSso:
if (policy.Enabled)
{
await DependsOnSingleOrgAsync(org);
}
else
{
await RequiredByKeyConnectorAsync(org);
}
break;
case PolicyType.MaximumVaultTimeout:
if (policy.Enabled)
{
await DependsOnSingleOrgAsync(org);
}
break;
}
var now = DateTime.UtcNow;
if (policy.Id == default(Guid))
{
policy.CreationDate = now;
}
if (policy.Enabled)
{
var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id);
if (!currentPolicy?.Enabled ?? true)
{
case Enums.PolicyType.TwoFactorAuthentication:
foreach (var orgUser in removableOrgUsers)
{
if (!await userService.TwoFactorIsEnabledAsync(orgUser))
var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(
policy.OrganizationId);
var removableOrgUsers = orgUsers.Where(ou =>
ou.Status != Enums.OrganizationUserStatusType.Invited &&
ou.Type != Enums.OrganizationUserType.Owner && ou.Type != Enums.OrganizationUserType.Admin &&
ou.UserId != savingUserId);
switch (policy.Type)
{
case Enums.PolicyType.TwoFactorAuthentication:
foreach (var orgUser in removableOrgUsers)
{
await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id,
savingUserId);
await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(
org.Name, orgUser.Email);
if (!await userService.TwoFactorIsEnabledAsync(orgUser))
{
await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id,
savingUserId);
await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(
org.Name, orgUser.Email);
}
}
}
break;
case Enums.PolicyType.SingleOrg:
var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(
removableOrgUsers.Select(ou => ou.UserId.Value));
foreach (var orgUser in removableOrgUsers)
{
if (userOrgs.Any(ou => ou.UserId == orgUser.UserId
&& ou.OrganizationId != org.Id
&& ou.Status != OrganizationUserStatusType.Invited))
break;
case Enums.PolicyType.SingleOrg:
var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(
removableOrgUsers.Select(ou => ou.UserId.Value));
foreach (var orgUser in removableOrgUsers)
{
await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id,
savingUserId);
await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(
org.Name, orgUser.Email);
if (userOrgs.Any(ou => ou.UserId == orgUser.UserId
&& ou.OrganizationId != org.Id
&& ou.Status != OrganizationUserStatusType.Invited))
{
await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id,
savingUserId);
await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(
org.Name, orgUser.Email);
}
}
}
break;
default:
break;
break;
default:
break;
}
}
}
policy.RevisionDate = now;
await _policyRepository.UpsertAsync(policy);
await _eventService.LogPolicyEventAsync(policy, Enums.EventType.Policy_Updated);
}
policy.RevisionDate = now;
await _policyRepository.UpsertAsync(policy);
await _eventService.LogPolicyEventAsync(policy, Enums.EventType.Policy_Updated);
}
private async Task DependsOnSingleOrgAsync(Organization org)
{
var singleOrg = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.SingleOrg);
if (singleOrg?.Enabled != true)
private async Task DependsOnSingleOrgAsync(Organization org)
{
throw new BadRequestException("Single Organization policy not enabled.");
var singleOrg = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.SingleOrg);
if (singleOrg?.Enabled != true)
{
throw new BadRequestException("Single Organization policy not enabled.");
}
}
}
private async Task RequiredBySsoAsync(Organization org)
{
var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.RequireSso);
if (requireSso?.Enabled == true)
private async Task RequiredBySsoAsync(Organization org)
{
throw new BadRequestException("Single Sign-On Authentication policy is enabled.");
var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.RequireSso);
if (requireSso?.Enabled == true)
{
throw new BadRequestException("Single Sign-On Authentication policy is enabled.");
}
}
}
private async Task RequiredByKeyConnectorAsync(Organization org)
{
var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id);
if (ssoConfig?.GetData()?.KeyConnectorEnabled == true)
private async Task RequiredByKeyConnectorAsync(Organization org)
{
throw new BadRequestException("Key Connector is enabled.");
var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id);
if (ssoConfig?.GetData()?.KeyConnectorEnabled == true)
{
throw new BadRequestException("Key Connector is enabled.");
}
}
}
private async Task RequiredByVaultTimeoutAsync(Organization org)
{
var vaultTimeout = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.MaximumVaultTimeout);
if (vaultTimeout?.Enabled == true)
private async Task RequiredByVaultTimeoutAsync(Organization org)
{
throw new BadRequestException("Maximum Vault Timeout policy is enabled.");
var vaultTimeout = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.MaximumVaultTimeout);
if (vaultTimeout?.Enabled == true)
{
throw new BadRequestException("Maximum Vault Timeout policy is enabled.");
}
}
}
}

View File

@ -8,218 +8,219 @@ using Bit.Core.Settings;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class RelayPushNotificationService : BaseIdentityClientService, IPushNotificationService
namespace Bit.Core.Services
{
private readonly IDeviceRepository _deviceRepository;
private readonly IHttpContextAccessor _httpContextAccessor;
public RelayPushNotificationService(
IHttpClientFactory httpFactory,
IDeviceRepository deviceRepository,
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor,
ILogger<RelayPushNotificationService> logger)
: base(
httpFactory,
globalSettings.PushRelayBaseUri,
globalSettings.Installation.IdentityUri,
"api.push",
$"installation.{globalSettings.Installation.Id}",
globalSettings.Installation.Key,
logger)
public class RelayPushNotificationService : BaseIdentityClientService, IPushNotificationService
{
_deviceRepository = deviceRepository;
_httpContextAccessor = httpContextAccessor;
}
private readonly IDeviceRepository _deviceRepository;
private readonly IHttpContextAccessor _httpContextAccessor;
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds);
}
public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds);
}
public async Task PushSyncCipherDeleteAsync(Cipher cipher)
{
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
{
if (cipher.OrganizationId.HasValue)
public RelayPushNotificationService(
IHttpClientFactory httpFactory,
IDeviceRepository deviceRepository,
GlobalSettings globalSettings,
IHttpContextAccessor httpContextAccessor,
ILogger<RelayPushNotificationService> logger)
: base(
httpFactory,
globalSettings.PushRelayBaseUri,
globalSettings.Installation.IdentityUri,
"api.push",
$"installation.{globalSettings.Installation.Id}",
globalSettings.Installation.Key,
logger)
{
// We cannot send org pushes since access logic is much more complicated than just the fact that they belong
// to the organization. Potentially we could blindly send to just users that have the access all permission
// device registration needs to be more granular to handle that appropriately. A more brute force approach could
// me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts.
// await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true);
_deviceRepository = deviceRepository;
_httpContextAccessor = httpContextAccessor;
}
else if (cipher.UserId.HasValue)
{
var message = new SyncCipherPushNotification
{
Id = cipher.Id,
UserId = cipher.UserId,
OrganizationId = cipher.OrganizationId,
RevisionDate = cipher.RevisionDate,
};
await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true);
public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds);
}
}
public async Task PushSyncFolderCreateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderCreate);
}
public async Task PushSyncFolderUpdateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderUpdate);
}
public async Task PushSyncFolderDeleteAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderDelete);
}
private async Task PushFolderAsync(Folder folder, PushType type)
{
var message = new SyncFolderPushNotification
public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable<Guid> collectionIds)
{
Id = folder.Id,
UserId = folder.UserId,
RevisionDate = folder.RevisionDate
};
await SendPayloadToUserAsync(folder.UserId, type, message, true);
}
public async Task PushSyncCiphersAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncCiphers);
}
public async Task PushSyncVaultAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncVault);
}
public async Task PushSyncOrgKeysAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncOrgKeys);
}
public async Task PushSyncSettingsAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncSettings);
}
public async Task PushLogOutAsync(Guid userId)
{
await PushUserAsync(userId, PushType.LogOut);
}
private async Task PushUserAsync(Guid userId, PushType type)
{
var message = new UserPushNotification
{
UserId = userId,
Date = DateTime.UtcNow
};
await SendPayloadToUserAsync(userId, type, message, false);
}
public async Task PushSyncSendCreateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendCreate);
}
public async Task PushSyncSendUpdateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendUpdate);
}
public async Task PushSyncSendDeleteAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendDelete);
}
private async Task PushSendAsync(Send send, PushType type)
{
if (send.UserId.HasValue)
{
var message = new SyncSendPushNotification
{
Id = send.Id,
UserId = send.UserId.Value,
RevisionDate = send.RevisionDate
};
await SendPayloadToUserAsync(message.UserId, type, message, true);
await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds);
}
}
private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext)
{
var request = new PushSendRequestModel
public async Task PushSyncCipherDeleteAsync(Cipher cipher)
{
UserId = userId.ToString(),
Type = type,
Payload = payload
};
await PushCipherAsync(cipher, PushType.SyncLoginDelete, null);
}
await AddCurrentContextAsync(request, excludeCurrentContext);
await SendAsync(HttpMethod.Post, "push/send", request);
}
private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext)
{
var request = new PushSendRequestModel
private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable<Guid> collectionIds)
{
OrganizationId = orgId.ToString(),
Type = type,
Payload = payload
};
await AddCurrentContextAsync(request, excludeCurrentContext);
await SendAsync(HttpMethod.Post, "push/send", request);
}
private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier)
{
var currentContext = _httpContextAccessor?.HttpContext?.
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier))
{
var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier);
if (device != null)
if (cipher.OrganizationId.HasValue)
{
request.DeviceId = device.Id.ToString();
// We cannot send org pushes since access logic is much more complicated than just the fact that they belong
// to the organization. Potentially we could blindly send to just users that have the access all permission
// device registration needs to be more granular to handle that appropriately. A more brute force approach could
// me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts.
// await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true);
}
if (addIdentifier)
else if (cipher.UserId.HasValue)
{
request.Identifier = currentContext.DeviceIdentifier;
var message = new SyncCipherPushNotification
{
Id = cipher.Id,
UserId = cipher.UserId,
OrganizationId = cipher.OrganizationId,
RevisionDate = cipher.RevisionDate,
};
await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true);
}
}
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
throw new NotImplementedException();
}
public async Task PushSyncFolderCreateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderCreate);
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
throw new NotImplementedException();
public async Task PushSyncFolderUpdateAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderUpdate);
}
public async Task PushSyncFolderDeleteAsync(Folder folder)
{
await PushFolderAsync(folder, PushType.SyncFolderDelete);
}
private async Task PushFolderAsync(Folder folder, PushType type)
{
var message = new SyncFolderPushNotification
{
Id = folder.Id,
UserId = folder.UserId,
RevisionDate = folder.RevisionDate
};
await SendPayloadToUserAsync(folder.UserId, type, message, true);
}
public async Task PushSyncCiphersAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncCiphers);
}
public async Task PushSyncVaultAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncVault);
}
public async Task PushSyncOrgKeysAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncOrgKeys);
}
public async Task PushSyncSettingsAsync(Guid userId)
{
await PushUserAsync(userId, PushType.SyncSettings);
}
public async Task PushLogOutAsync(Guid userId)
{
await PushUserAsync(userId, PushType.LogOut);
}
private async Task PushUserAsync(Guid userId, PushType type)
{
var message = new UserPushNotification
{
UserId = userId,
Date = DateTime.UtcNow
};
await SendPayloadToUserAsync(userId, type, message, false);
}
public async Task PushSyncSendCreateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendCreate);
}
public async Task PushSyncSendUpdateAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendUpdate);
}
public async Task PushSyncSendDeleteAsync(Send send)
{
await PushSendAsync(send, PushType.SyncSendDelete);
}
private async Task PushSendAsync(Send send, PushType type)
{
if (send.UserId.HasValue)
{
var message = new SyncSendPushNotification
{
Id = send.Id,
UserId = send.UserId.Value,
RevisionDate = send.RevisionDate
};
await SendPayloadToUserAsync(message.UserId, type, message, true);
}
}
private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext)
{
var request = new PushSendRequestModel
{
UserId = userId.ToString(),
Type = type,
Payload = payload
};
await AddCurrentContextAsync(request, excludeCurrentContext);
await SendAsync(HttpMethod.Post, "push/send", request);
}
private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext)
{
var request = new PushSendRequestModel
{
OrganizationId = orgId.ToString(),
Type = type,
Payload = payload
};
await AddCurrentContextAsync(request, excludeCurrentContext);
await SendAsync(HttpMethod.Post, "push/send", request);
}
private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier)
{
var currentContext = _httpContextAccessor?.HttpContext?.
RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext;
if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier))
{
var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier);
if (device != null)
{
request.DeviceId = device.Id.ToString();
}
if (addIdentifier)
{
request.Identifier = currentContext.DeviceIdentifier;
}
}
}
public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier,
string deviceId = null)
{
throw new NotImplementedException();
}
public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier,
string deviceId = null)
{
throw new NotImplementedException();
}
}
}

View File

@ -3,64 +3,65 @@ using Bit.Core.Models.Api;
using Bit.Core.Settings;
using Microsoft.Extensions.Logging;
namespace Bit.Core.Services;
public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegistrationService
namespace Bit.Core.Services
{
public RelayPushRegistrationService(
IHttpClientFactory httpFactory,
GlobalSettings globalSettings,
ILogger<RelayPushRegistrationService> logger)
: base(
httpFactory,
globalSettings.PushRelayBaseUri,
globalSettings.Installation.IdentityUri,
"api.push",
$"installation.{globalSettings.Installation.Id}",
globalSettings.Installation.Key,
logger)
public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegistrationService
{
}
public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId,
string identifier, DeviceType type)
{
var requestModel = new PushRegistrationRequestModel
public RelayPushRegistrationService(
IHttpClientFactory httpFactory,
GlobalSettings globalSettings,
ILogger<RelayPushRegistrationService> logger)
: base(
httpFactory,
globalSettings.PushRelayBaseUri,
globalSettings.Installation.IdentityUri,
"api.push",
$"installation.{globalSettings.Installation.Id}",
globalSettings.Installation.Key,
logger)
{
DeviceId = deviceId,
Identifier = identifier,
PushToken = pushToken,
Type = type,
UserId = userId
};
await SendAsync(HttpMethod.Post, "push/register", requestModel);
}
public async Task DeleteRegistrationAsync(string deviceId)
{
await SendAsync(HttpMethod.Delete, string.Concat("push/", deviceId));
}
public async Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
{
if (!deviceIds.Any())
{
return;
}
var requestModel = new PushUpdateRequestModel(deviceIds, organizationId);
await SendAsync(HttpMethod.Put, "push/add-organization", requestModel);
}
public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
{
if (!deviceIds.Any())
public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId,
string identifier, DeviceType type)
{
return;
var requestModel = new PushRegistrationRequestModel
{
DeviceId = deviceId,
Identifier = identifier,
PushToken = pushToken,
Type = type,
UserId = userId
};
await SendAsync(HttpMethod.Post, "push/register", requestModel);
}
var requestModel = new PushUpdateRequestModel(deviceIds, organizationId);
await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel);
public async Task DeleteRegistrationAsync(string deviceId)
{
await SendAsync(HttpMethod.Delete, string.Concat("push/", deviceId));
}
public async Task AddUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
{
if (!deviceIds.Any())
{
return;
}
var requestModel = new PushUpdateRequestModel(deviceIds, organizationId);
await SendAsync(HttpMethod.Put, "push/add-organization", requestModel);
}
public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable<string> deviceIds, string organizationId)
{
if (!deviceIds.Any())
{
return;
}
var requestModel = new PushUpdateRequestModel(deviceIds, organizationId);
await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel);
}
}
}

View File

@ -1,25 +1,26 @@
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
namespace Bit.Core.Services;
public class RepositoryEventWriteService : IEventWriteService
namespace Bit.Core.Services
{
private readonly IEventRepository _eventRepository;
public RepositoryEventWriteService(
IEventRepository eventRepository)
public class RepositoryEventWriteService : IEventWriteService
{
_eventRepository = eventRepository;
}
private readonly IEventRepository _eventRepository;
public async Task CreateAsync(IEvent e)
{
await _eventRepository.CreateAsync(e);
}
public RepositoryEventWriteService(
IEventRepository eventRepository)
{
_eventRepository = eventRepository;
}
public async Task CreateManyAsync(IEnumerable<IEvent> e)
{
await _eventRepository.CreateManyAsync(e);
public async Task CreateAsync(IEvent e)
{
await _eventRepository.CreateAsync(e);
}
public async Task CreateManyAsync(IEnumerable<IEvent> e)
{
await _eventRepository.CreateManyAsync(e);
}
}
}

View File

@ -6,109 +6,110 @@ using Microsoft.Extensions.Logging;
using SendGrid;
using SendGrid.Helpers.Mail;
namespace Bit.Core.Services;
public class SendGridMailDeliveryService : IMailDeliveryService, IDisposable
namespace Bit.Core.Services
{
private readonly GlobalSettings _globalSettings;
private readonly IWebHostEnvironment _hostingEnvironment;
private readonly ILogger<SendGridMailDeliveryService> _logger;
private readonly ISendGridClient _client;
private readonly string _senderTag;
private readonly string _replyToEmail;
public SendGridMailDeliveryService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<SendGridMailDeliveryService> logger)
: this(new SendGridClient(globalSettings.Mail.SendGridApiKey),
globalSettings, hostingEnvironment, logger)
public class SendGridMailDeliveryService : IMailDeliveryService, IDisposable
{
}
private readonly GlobalSettings _globalSettings;
private readonly IWebHostEnvironment _hostingEnvironment;
private readonly ILogger<SendGridMailDeliveryService> _logger;
private readonly ISendGridClient _client;
private readonly string _senderTag;
private readonly string _replyToEmail;
public void Dispose()
{
// TODO: nothing to dispose
}
public SendGridMailDeliveryService(
ISendGridClient client,
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<SendGridMailDeliveryService> logger)
{
if (string.IsNullOrWhiteSpace(globalSettings.Mail?.SendGridApiKey))
public SendGridMailDeliveryService(
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<SendGridMailDeliveryService> logger)
: this(new SendGridClient(globalSettings.Mail.SendGridApiKey),
globalSettings, hostingEnvironment, logger)
{
throw new ArgumentNullException(nameof(globalSettings.Mail.SendGridApiKey));
}
_globalSettings = globalSettings;
_hostingEnvironment = hostingEnvironment;
_logger = logger;
_client = client;
_senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}";
_replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail);
}
public async Task SendEmailAsync(MailMessage message)
{
var msg = new SendGridMessage();
msg.SetFrom(new EmailAddress(_replyToEmail, _globalSettings.SiteName));
msg.AddTos(message.ToEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList());
if (message.BccEmails?.Any() ?? false)
public void Dispose()
{
msg.AddBccs(message.BccEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList());
// TODO: nothing to dispose
}
msg.SetSubject(message.Subject);
msg.AddContent(MimeType.Text, message.TextContent);
msg.AddContent(MimeType.Html, message.HtmlContent);
msg.AddCategory($"type:{message.Category}");
msg.AddCategory($"env:{_hostingEnvironment.EnvironmentName}");
msg.AddCategory($"sender:{_senderTag}");
msg.SetClickTracking(false, false);
msg.SetOpenTracking(false);
if (message.MetaData != null &&
message.MetaData.ContainsKey("SendGridBypassListManagement") &&
Convert.ToBoolean(message.MetaData["SendGridBypassListManagement"]))
public SendGridMailDeliveryService(
ISendGridClient client,
GlobalSettings globalSettings,
IWebHostEnvironment hostingEnvironment,
ILogger<SendGridMailDeliveryService> logger)
{
msg.SetBypassListManagement(true);
}
try
{
var success = await SendAsync(msg, false);
if (!success)
if (string.IsNullOrWhiteSpace(globalSettings.Mail?.SendGridApiKey))
{
_logger.LogWarning("Failed to send email. Retrying...");
throw new ArgumentNullException(nameof(globalSettings.Mail.SendGridApiKey));
}
_globalSettings = globalSettings;
_hostingEnvironment = hostingEnvironment;
_logger = logger;
_client = client;
_senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}";
_replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail);
}
public async Task SendEmailAsync(MailMessage message)
{
var msg = new SendGridMessage();
msg.SetFrom(new EmailAddress(_replyToEmail, _globalSettings.SiteName));
msg.AddTos(message.ToEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList());
if (message.BccEmails?.Any() ?? false)
{
msg.AddBccs(message.BccEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList());
}
msg.SetSubject(message.Subject);
msg.AddContent(MimeType.Text, message.TextContent);
msg.AddContent(MimeType.Html, message.HtmlContent);
msg.AddCategory($"type:{message.Category}");
msg.AddCategory($"env:{_hostingEnvironment.EnvironmentName}");
msg.AddCategory($"sender:{_senderTag}");
msg.SetClickTracking(false, false);
msg.SetOpenTracking(false);
if (message.MetaData != null &&
message.MetaData.ContainsKey("SendGridBypassListManagement") &&
Convert.ToBoolean(message.MetaData["SendGridBypassListManagement"]))
{
msg.SetBypassListManagement(true);
}
try
{
var success = await SendAsync(msg, false);
if (!success)
{
_logger.LogWarning("Failed to send email. Retrying...");
await SendAsync(msg, true);
}
}
catch (Exception e)
{
_logger.LogWarning(e, "Failed to send email (with exception). Retrying...");
await SendAsync(msg, true);
throw;
}
}
catch (Exception e)
{
_logger.LogWarning(e, "Failed to send email (with exception). Retrying...");
await SendAsync(msg, true);
throw;
}
}
private async Task<bool> SendAsync(SendGridMessage message, bool retry)
{
if (retry)
private async Task<bool> SendAsync(SendGridMessage message, bool retry)
{
// wait and try again
await Task.Delay(2000);
}
if (retry)
{
// wait and try again
await Task.Delay(2000);
}
var response = await _client.SendEmailAsync(message);
if (!response.IsSuccessStatusCode)
{
var responseBody = await response.Body.ReadAsStringAsync();
_logger.LogError("SendGrid email sending failed with {0}: {1}", response.StatusCode, responseBody);
var response = await _client.SendEmailAsync(message);
if (!response.IsSuccessStatusCode)
{
var responseBody = await response.Body.ReadAsStringAsync();
_logger.LogError("SendGrid email sending failed with {0}: {1}", response.StatusCode, responseBody);
}
return response.IsSuccessStatusCode;
}
return response.IsSuccessStatusCode;
}
}

View File

@ -11,329 +11,330 @@ using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Identity;
namespace Bit.Core.Services;
public class SendService : ISendService
namespace Bit.Core.Services
{
public const long MAX_FILE_SIZE = Constants.FileSize501mb;
public const string MAX_FILE_SIZE_READABLE = "500 MB";
private readonly ISendRepository _sendRepository;
private readonly IUserRepository _userRepository;
private readonly IPolicyRepository _policyRepository;
private readonly IUserService _userService;
private readonly IOrganizationRepository _organizationRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IPushNotificationService _pushService;
private readonly IReferenceEventService _referenceEventService;
private readonly GlobalSettings _globalSettings;
private readonly ICurrentContext _currentContext;
private const long _fileSizeLeeway = 1024L * 1024L; // 1MB
public SendService(
ISendRepository sendRepository,
IUserRepository userRepository,
IUserService userService,
IOrganizationRepository organizationRepository,
ISendFileStorageService sendFileStorageService,
IPasswordHasher<User> passwordHasher,
IPushNotificationService pushService,
IReferenceEventService referenceEventService,
GlobalSettings globalSettings,
IPolicyRepository policyRepository,
ICurrentContext currentContext)
public class SendService : ISendService
{
_sendRepository = sendRepository;
_userRepository = userRepository;
_userService = userService;
_policyRepository = policyRepository;
_organizationRepository = organizationRepository;
_sendFileStorageService = sendFileStorageService;
_passwordHasher = passwordHasher;
_pushService = pushService;
_referenceEventService = referenceEventService;
_globalSettings = globalSettings;
_currentContext = currentContext;
}
public const long MAX_FILE_SIZE = Constants.FileSize501mb;
public const string MAX_FILE_SIZE_READABLE = "500 MB";
private readonly ISendRepository _sendRepository;
private readonly IUserRepository _userRepository;
private readonly IPolicyRepository _policyRepository;
private readonly IUserService _userService;
private readonly IOrganizationRepository _organizationRepository;
private readonly ISendFileStorageService _sendFileStorageService;
private readonly IPasswordHasher<User> _passwordHasher;
private readonly IPushNotificationService _pushService;
private readonly IReferenceEventService _referenceEventService;
private readonly GlobalSettings _globalSettings;
private readonly ICurrentContext _currentContext;
private const long _fileSizeLeeway = 1024L * 1024L; // 1MB
public async Task SaveSendAsync(Send send)
{
// Make sure user can save Sends
await ValidateUserCanSaveAsync(send.UserId, send);
if (send.Id == default(Guid))
public SendService(
ISendRepository sendRepository,
IUserRepository userRepository,
IUserService userService,
IOrganizationRepository organizationRepository,
ISendFileStorageService sendFileStorageService,
IPasswordHasher<User> passwordHasher,
IPushNotificationService pushService,
IReferenceEventService referenceEventService,
GlobalSettings globalSettings,
IPolicyRepository policyRepository,
ICurrentContext currentContext)
{
await _sendRepository.CreateAsync(send);
await _pushService.PushSyncSendCreateAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendCreated);
}
else
{
send.RevisionDate = DateTime.UtcNow;
await _sendRepository.UpsertAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
}
}
public async Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Send is not of type \"file\".");
_sendRepository = sendRepository;
_userRepository = userRepository;
_userService = userService;
_policyRepository = policyRepository;
_organizationRepository = organizationRepository;
_sendFileStorageService = sendFileStorageService;
_passwordHasher = passwordHasher;
_pushService = pushService;
_referenceEventService = referenceEventService;
_globalSettings = globalSettings;
_currentContext = currentContext;
}
if (fileLength < 1)
public async Task SaveSendAsync(Send send)
{
throw new BadRequestException("No file data.");
}
// Make sure user can save Sends
await ValidateUserCanSaveAsync(send.UserId, send);
var storageBytesRemaining = await StorageRemainingForSendAsync(send);
if (storageBytesRemaining < fileLength)
{
throw new BadRequestException("Not enough storage available.");
}
var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false);
try
{
data.Id = fileId;
data.Size = fileLength;
data.Validated = false;
send.Data = JsonSerializer.Serialize(data,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId);
}
catch
{
// Clean up since this is not transactional
await _sendFileStorageService.DeleteFileAsync(send, fileId);
throw;
}
}
public async Task UploadFileToExistingSendAsync(Stream stream, Send send)
{
if (send?.Data == null)
{
throw new BadRequestException("Send does not have file data");
}
if (send.Type != SendType.File)
{
throw new BadRequestException("Not a File Type Send.");
}
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
if (data.Validated)
{
throw new BadRequestException("File has already been uploaded.");
}
await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id);
if (!await ValidateSendFile(send))
{
throw new BadRequestException("File received does not match expected file length.");
}
}
public async Task<bool> ValidateSendFile(Send send)
{
var fileData = JsonSerializer.Deserialize<SendFileData>(send.Data);
var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway);
if (!valid || realSize > MAX_FILE_SIZE)
{
// File reported differs in size from that promised. Must be a rogue client. Delete Send
await DeleteSendAsync(send);
return false;
}
// Update Send data if necessary
if (realSize != fileData.Size)
{
fileData.Size = realSize.Value;
}
fileData.Validated = true;
send.Data = JsonSerializer.Serialize(fileData,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return valid;
}
public async Task DeleteSendAsync(Send send)
{
await _sendRepository.DeleteAsync(send);
if (send.Type == Enums.SendType.File)
{
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
await _sendFileStorageService.DeleteFileAsync(send, data.Id);
}
await _pushService.PushSyncSendDeleteAsync(send);
}
public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send,
string password)
{
var now = DateTime.UtcNow;
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
send.DeletionDate < now)
{
return (false, false, false);
}
if (!string.IsNullOrWhiteSpace(send.Password))
{
if (string.IsNullOrWhiteSpace(password))
if (send.Id == default(Guid))
{
return (false, true, false);
}
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
{
send.Password = HashPassword(password);
}
if (passwordResult == PasswordVerificationResult.Failed)
{
return (false, false, true);
}
}
return (true, false, false);
}
// Response: Send, password required, password invalid
public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Can only get a download URL for a file type of Send");
}
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
send.AccessCount++;
await _sendRepository.ReplaceAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false);
}
// Response: Send, password required, password invalid
public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password)
{
var send = await _sendRepository.GetByIdAsync(sendId);
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
// TODO: maybe move this to a simple ++ sproc?
if (send.Type != SendType.File)
{
// File sends are incremented during file download
send.AccessCount++;
}
await _sendRepository.ReplaceAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed);
return (send, false, false);
}
private async Task RaiseReferenceEventAsync(Send send, ReferenceEventType eventType)
{
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
{
Id = send.UserId ?? default,
Type = eventType,
Source = ReferenceEventSource.User,
SendType = send.Type,
MaxAccessCount = send.MaxAccessCount,
HasPassword = !string.IsNullOrWhiteSpace(send.Password),
});
}
public string HashPassword(string password)
{
return _passwordHasher.HashPassword(new User(), password);
}
private async Task ValidateUserCanSaveAsync(Guid? userId, Send send)
{
if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true))
{
return;
}
var disableSendPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value,
PolicyType.DisableSend);
if (disableSendPolicyCount > 0)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
if (send.HideEmail.GetValueOrDefault())
{
var sendOptionsPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId.Value, PolicyType.SendOptions);
if (sendOptionsPolicies.Any(p => p.GetDataModel<SendOptionsPolicyData>()?.DisableHideEmail ?? false))
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
}
private async Task<long> StorageRemainingForSendAsync(Send send)
{
var storageBytesRemaining = 0L;
if (send.UserId.HasValue)
{
var user = await _userRepository.GetByIdAsync(send.UserId.Value);
if (!await _userService.CanAccessPremium(user))
{
throw new BadRequestException("You must have premium status to use file Sends.");
}
if (!user.EmailVerified)
{
throw new BadRequestException("You must confirm your email to use file Sends.");
}
if (user.Premium)
{
storageBytesRemaining = user.StorageBytesRemaining();
await _sendRepository.CreateAsync(send);
await _pushService.PushSyncSendCreateAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendCreated);
}
else
{
// Users that get access to file storage/premium from their organization get the default
// 1 GB max storage.
storageBytesRemaining = user.StorageBytesRemaining(
_globalSettings.SelfHosted ? (short)10240 : (short)1);
send.RevisionDate = DateTime.UtcNow;
await _sendRepository.UpsertAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
}
}
else if (send.OrganizationId.HasValue)
public async Task<string> SaveFileSendAsync(Send send, SendFileData data, long fileLength)
{
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
if (!org.MaxStorageGb.HasValue)
if (send.Type != SendType.File)
{
throw new BadRequestException("This organization cannot use file sends.");
throw new BadRequestException("Send is not of type \"file\".");
}
storageBytesRemaining = org.StorageBytesRemaining();
if (fileLength < 1)
{
throw new BadRequestException("No file data.");
}
var storageBytesRemaining = await StorageRemainingForSendAsync(send);
if (storageBytesRemaining < fileLength)
{
throw new BadRequestException("Not enough storage available.");
}
var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false);
try
{
data.Id = fileId;
data.Size = fileLength;
data.Validated = false;
send.Data = JsonSerializer.Serialize(data,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId);
}
catch
{
// Clean up since this is not transactional
await _sendFileStorageService.DeleteFileAsync(send, fileId);
throw;
}
}
return storageBytesRemaining;
public async Task UploadFileToExistingSendAsync(Stream stream, Send send)
{
if (send?.Data == null)
{
throw new BadRequestException("Send does not have file data");
}
if (send.Type != SendType.File)
{
throw new BadRequestException("Not a File Type Send.");
}
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
if (data.Validated)
{
throw new BadRequestException("File has already been uploaded.");
}
await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id);
if (!await ValidateSendFile(send))
{
throw new BadRequestException("File received does not match expected file length.");
}
}
public async Task<bool> ValidateSendFile(Send send)
{
var fileData = JsonSerializer.Deserialize<SendFileData>(send.Data);
var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway);
if (!valid || realSize > MAX_FILE_SIZE)
{
// File reported differs in size from that promised. Must be a rogue client. Delete Send
await DeleteSendAsync(send);
return false;
}
// Update Send data if necessary
if (realSize != fileData.Size)
{
fileData.Size = realSize.Value;
}
fileData.Validated = true;
send.Data = JsonSerializer.Serialize(fileData,
JsonHelpers.IgnoreWritingNull);
await SaveSendAsync(send);
return valid;
}
public async Task DeleteSendAsync(Send send)
{
await _sendRepository.DeleteAsync(send);
if (send.Type == Enums.SendType.File)
{
var data = JsonSerializer.Deserialize<SendFileData>(send.Data);
await _sendFileStorageService.DeleteFileAsync(send, data.Id);
}
await _pushService.PushSyncSendDeleteAsync(send);
}
public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send,
string password)
{
var now = DateTime.UtcNow;
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
send.DeletionDate < now)
{
return (false, false, false);
}
if (!string.IsNullOrWhiteSpace(send.Password))
{
if (string.IsNullOrWhiteSpace(password))
{
return (false, true, false);
}
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
{
send.Password = HashPassword(password);
}
if (passwordResult == PasswordVerificationResult.Failed)
{
return (false, false, true);
}
}
return (true, false, false);
}
// Response: Send, password required, password invalid
public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Can only get a download URL for a file type of Send");
}
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
send.AccessCount++;
await _sendRepository.ReplaceAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false);
}
// Response: Send, password required, password invalid
public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password)
{
var send = await _sendRepository.GetByIdAsync(sendId);
var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
if (!grantAccess)
{
return (null, passwordRequired, passwordInvalid);
}
// TODO: maybe move this to a simple ++ sproc?
if (send.Type != SendType.File)
{
// File sends are incremented during file download
send.AccessCount++;
}
await _sendRepository.ReplaceAsync(send);
await _pushService.PushSyncSendUpdateAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed);
return (send, false, false);
}
private async Task RaiseReferenceEventAsync(Send send, ReferenceEventType eventType)
{
await _referenceEventService.RaiseEventAsync(new ReferenceEvent
{
Id = send.UserId ?? default,
Type = eventType,
Source = ReferenceEventSource.User,
SendType = send.Type,
MaxAccessCount = send.MaxAccessCount,
HasPassword = !string.IsNullOrWhiteSpace(send.Password),
});
}
public string HashPassword(string password)
{
return _passwordHasher.HashPassword(new User(), password);
}
private async Task ValidateUserCanSaveAsync(Guid? userId, Send send)
{
if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true))
{
return;
}
var disableSendPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value,
PolicyType.DisableSend);
if (disableSendPolicyCount > 0)
{
throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send.");
}
if (send.HideEmail.GetValueOrDefault())
{
var sendOptionsPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId.Value, PolicyType.SendOptions);
if (sendOptionsPolicies.Any(p => p.GetDataModel<SendOptionsPolicyData>()?.DisableHideEmail ?? false))
{
throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send.");
}
}
}
private async Task<long> StorageRemainingForSendAsync(Send send)
{
var storageBytesRemaining = 0L;
if (send.UserId.HasValue)
{
var user = await _userRepository.GetByIdAsync(send.UserId.Value);
if (!await _userService.CanAccessPremium(user))
{
throw new BadRequestException("You must have premium status to use file Sends.");
}
if (!user.EmailVerified)
{
throw new BadRequestException("You must confirm your email to use file Sends.");
}
if (user.Premium)
{
storageBytesRemaining = user.StorageBytesRemaining();
}
else
{
// Users that get access to file storage/premium from their organization get the default
// 1 GB max storage.
storageBytesRemaining = user.StorageBytesRemaining(
_globalSettings.SelfHosted ? (short)10240 : (short)1);
}
}
else if (send.OrganizationId.HasValue)
{
var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
if (!org.MaxStorageGb.HasValue)
{
throw new BadRequestException("This organization cannot use file sends.");
}
storageBytesRemaining = org.StorageBytesRemaining();
}
return storageBytesRemaining;
}
}
}

View File

@ -3,104 +3,105 @@ using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Repositories;
namespace Bit.Core.Services;
public class SsoConfigService : ISsoConfigService
namespace Bit.Core.Services
{
private readonly ISsoConfigRepository _ssoConfigRepository;
private readonly IPolicyRepository _policyRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IEventService _eventService;
public SsoConfigService(
ISsoConfigRepository ssoConfigRepository,
IPolicyRepository policyRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IEventService eventService)
public class SsoConfigService : ISsoConfigService
{
_ssoConfigRepository = ssoConfigRepository;
_policyRepository = policyRepository;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_eventService = eventService;
}
private readonly ISsoConfigRepository _ssoConfigRepository;
private readonly IPolicyRepository _policyRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly IEventService _eventService;
public async Task SaveAsync(SsoConfig config, Organization organization)
{
var now = DateTime.UtcNow;
config.RevisionDate = now;
if (config.Id == default)
public SsoConfigService(
ISsoConfigRepository ssoConfigRepository,
IPolicyRepository policyRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IEventService eventService)
{
config.CreationDate = now;
_ssoConfigRepository = ssoConfigRepository;
_policyRepository = policyRepository;
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_eventService = eventService;
}
var useKeyConnector = config.GetData().KeyConnectorEnabled;
if (useKeyConnector)
public async Task SaveAsync(SsoConfig config, Organization organization)
{
await VerifyDependenciesAsync(config, organization);
var now = DateTime.UtcNow;
config.RevisionDate = now;
if (config.Id == default)
{
config.CreationDate = now;
}
var useKeyConnector = config.GetData().KeyConnectorEnabled;
if (useKeyConnector)
{
await VerifyDependenciesAsync(config, organization);
}
var oldConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(config.OrganizationId);
var disabledKeyConnector = oldConfig?.GetData()?.KeyConnectorEnabled == true && !useKeyConnector;
if (disabledKeyConnector && await AnyOrgUserHasKeyConnectorEnabledAsync(config.OrganizationId))
{
throw new BadRequestException("Key Connector cannot be disabled at this moment.");
}
await LogEventsAsync(config, oldConfig);
await _ssoConfigRepository.UpsertAsync(config);
}
var oldConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(config.OrganizationId);
var disabledKeyConnector = oldConfig?.GetData()?.KeyConnectorEnabled == true && !useKeyConnector;
if (disabledKeyConnector && await AnyOrgUserHasKeyConnectorEnabledAsync(config.OrganizationId))
private async Task<bool> AnyOrgUserHasKeyConnectorEnabledAsync(Guid organizationId)
{
throw new BadRequestException("Key Connector cannot be disabled at this moment.");
var userDetails =
await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
return userDetails.Any(u => u.UsesKeyConnector);
}
await LogEventsAsync(config, oldConfig);
await _ssoConfigRepository.UpsertAsync(config);
}
private async Task<bool> AnyOrgUserHasKeyConnectorEnabledAsync(Guid organizationId)
{
var userDetails =
await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId);
return userDetails.Any(u => u.UsesKeyConnector);
}
private async Task VerifyDependenciesAsync(SsoConfig config, Organization organization)
{
if (!organization.UseKeyConnector)
private async Task VerifyDependenciesAsync(SsoConfig config, Organization organization)
{
throw new BadRequestException("Organization cannot use Key Connector.");
if (!organization.UseKeyConnector)
{
throw new BadRequestException("Organization cannot use Key Connector.");
}
var singleOrgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.SingleOrg);
if (singleOrgPolicy is not { Enabled: true })
{
throw new BadRequestException("Key Connector requires the Single Organization policy to be enabled.");
}
var ssoPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso);
if (ssoPolicy is not { Enabled: true })
{
throw new BadRequestException("Key Connector requires the Single Sign-On Authentication policy to be enabled.");
}
if (!config.Enabled)
{
throw new BadRequestException("You must enable SSO to use Key Connector.");
}
}
var singleOrgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.SingleOrg);
if (singleOrgPolicy is not { Enabled: true })
private async Task LogEventsAsync(SsoConfig config, SsoConfig oldConfig)
{
throw new BadRequestException("Key Connector requires the Single Organization policy to be enabled.");
}
var organization = await _organizationRepository.GetByIdAsync(config.OrganizationId);
if (oldConfig?.Enabled != config.Enabled)
{
var e = config.Enabled ? EventType.Organization_EnabledSso : EventType.Organization_DisabledSso;
await _eventService.LogOrganizationEventAsync(organization, e);
}
var ssoPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso);
if (ssoPolicy is not { Enabled: true })
{
throw new BadRequestException("Key Connector requires the Single Sign-On Authentication policy to be enabled.");
}
if (!config.Enabled)
{
throw new BadRequestException("You must enable SSO to use Key Connector.");
}
}
private async Task LogEventsAsync(SsoConfig config, SsoConfig oldConfig)
{
var organization = await _organizationRepository.GetByIdAsync(config.OrganizationId);
if (oldConfig?.Enabled != config.Enabled)
{
var e = config.Enabled ? EventType.Organization_EnabledSso : EventType.Organization_DisabledSso;
await _eventService.LogOrganizationEventAsync(organization, e);
}
var keyConnectorEnabled = config.GetData().KeyConnectorEnabled;
if (oldConfig?.GetData()?.KeyConnectorEnabled != keyConnectorEnabled)
{
var e = keyConnectorEnabled
? EventType.Organization_EnabledKeyConnector
: EventType.Organization_DisabledKeyConnector;
await _eventService.LogOrganizationEventAsync(organization, e);
var keyConnectorEnabled = config.GetData().KeyConnectorEnabled;
if (oldConfig?.GetData()?.KeyConnectorEnabled != keyConnectorEnabled)
{
var e = keyConnectorEnabled
? EventType.Organization_EnabledKeyConnector
: EventType.Organization_DisabledKeyConnector;
await _eventService.LogOrganizationEventAsync(organization, e);
}
}
}
}

View File

@ -1,217 +1,218 @@
using Bit.Core.Models.BitStripe;
namespace Bit.Core.Services;
public class StripeAdapter : IStripeAdapter
namespace Bit.Core.Services
{
private readonly Stripe.CustomerService _customerService;
private readonly Stripe.SubscriptionService _subscriptionService;
private readonly Stripe.InvoiceService _invoiceService;
private readonly Stripe.PaymentMethodService _paymentMethodService;
private readonly Stripe.TaxRateService _taxRateService;
private readonly Stripe.TaxIdService _taxIdService;
private readonly Stripe.ChargeService _chargeService;
private readonly Stripe.RefundService _refundService;
private readonly Stripe.CardService _cardService;
private readonly Stripe.BankAccountService _bankAccountService;
private readonly Stripe.PriceService _priceService;
private readonly Stripe.TestHelpers.TestClockService _testClockService;
public StripeAdapter()
public class StripeAdapter : IStripeAdapter
{
_customerService = new Stripe.CustomerService();
_subscriptionService = new Stripe.SubscriptionService();
_invoiceService = new Stripe.InvoiceService();
_paymentMethodService = new Stripe.PaymentMethodService();
_taxRateService = new Stripe.TaxRateService();
_taxIdService = new Stripe.TaxIdService();
_chargeService = new Stripe.ChargeService();
_refundService = new Stripe.RefundService();
_cardService = new Stripe.CardService();
_bankAccountService = new Stripe.BankAccountService();
_priceService = new Stripe.PriceService();
_testClockService = new Stripe.TestHelpers.TestClockService();
}
private readonly Stripe.CustomerService _customerService;
private readonly Stripe.SubscriptionService _subscriptionService;
private readonly Stripe.InvoiceService _invoiceService;
private readonly Stripe.PaymentMethodService _paymentMethodService;
private readonly Stripe.TaxRateService _taxRateService;
private readonly Stripe.TaxIdService _taxIdService;
private readonly Stripe.ChargeService _chargeService;
private readonly Stripe.RefundService _refundService;
private readonly Stripe.CardService _cardService;
private readonly Stripe.BankAccountService _bankAccountService;
private readonly Stripe.PriceService _priceService;
private readonly Stripe.TestHelpers.TestClockService _testClockService;
public Task<Stripe.Customer> CustomerCreateAsync(Stripe.CustomerCreateOptions options)
{
return _customerService.CreateAsync(options);
}
public Task<Stripe.Customer> CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null)
{
return _customerService.GetAsync(id, options);
}
public Task<Stripe.Customer> CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null)
{
return _customerService.UpdateAsync(id, options);
}
public Task<Stripe.Customer> CustomerDeleteAsync(string id)
{
return _customerService.DeleteAsync(id);
}
public Task<Stripe.Subscription> SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options)
{
return _subscriptionService.CreateAsync(options);
}
public Task<Stripe.Subscription> SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null)
{
return _subscriptionService.GetAsync(id, options);
}
public Task<Stripe.Subscription> SubscriptionUpdateAsync(string id,
Stripe.SubscriptionUpdateOptions options = null)
{
return _subscriptionService.UpdateAsync(id, options);
}
public Task<Stripe.Subscription> SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null)
{
return _subscriptionService.CancelAsync(Id, options);
}
public Task<Stripe.Invoice> InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options)
{
return _invoiceService.UpcomingAsync(options);
}
public Task<Stripe.Invoice> InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options)
{
return _invoiceService.GetAsync(id, options);
}
public Task<Stripe.StripeList<Stripe.Invoice>> InvoiceListAsync(Stripe.InvoiceListOptions options)
{
return _invoiceService.ListAsync(options);
}
public Task<Stripe.Invoice> InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options)
{
return _invoiceService.UpdateAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options)
{
return _invoiceService.FinalizeInvoiceAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options)
{
return _invoiceService.SendInvoiceAsync(id, options);
}
public Task<Stripe.Invoice> InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null)
{
return _invoiceService.PayAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null)
{
return _invoiceService.DeleteAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null)
{
return _invoiceService.VoidInvoiceAsync(id, options);
}
public IEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options)
{
return _paymentMethodService.ListAutoPaging(options);
}
public Task<Stripe.PaymentMethod> PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null)
{
return _paymentMethodService.AttachAsync(id, options);
}
public Task<Stripe.PaymentMethod> PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null)
{
return _paymentMethodService.DetachAsync(id, options);
}
public Task<Stripe.TaxRate> TaxRateCreateAsync(Stripe.TaxRateCreateOptions options)
{
return _taxRateService.CreateAsync(options);
}
public Task<Stripe.TaxRate> TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options)
{
return _taxRateService.UpdateAsync(id, options);
}
public Task<Stripe.TaxId> TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options)
{
return _taxIdService.CreateAsync(id, options);
}
public Task<Stripe.TaxId> TaxIdDeleteAsync(string customerId, string taxIdId,
Stripe.TaxIdDeleteOptions options = null)
{
return _taxIdService.DeleteAsync(customerId, taxIdId);
}
public Task<Stripe.StripeList<Stripe.Charge>> ChargeListAsync(Stripe.ChargeListOptions options)
{
return _chargeService.ListAsync(options);
}
public Task<Stripe.Refund> RefundCreateAsync(Stripe.RefundCreateOptions options)
{
return _refundService.CreateAsync(options);
}
public Task<Stripe.Card> CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null)
{
return _cardService.DeleteAsync(customerId, cardId, options);
}
public Task<Stripe.BankAccount> BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null)
{
return _bankAccountService.CreateAsync(customerId, options);
}
public Task<Stripe.BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null)
{
return _bankAccountService.DeleteAsync(customerId, bankAccount, options);
}
public async Task<List<Stripe.Subscription>> SubscriptionListAsync(StripeSubscriptionListOptions options)
{
if (!options.SelectAll)
public StripeAdapter()
{
return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data;
_customerService = new Stripe.CustomerService();
_subscriptionService = new Stripe.SubscriptionService();
_invoiceService = new Stripe.InvoiceService();
_paymentMethodService = new Stripe.PaymentMethodService();
_taxRateService = new Stripe.TaxRateService();
_taxIdService = new Stripe.TaxIdService();
_chargeService = new Stripe.ChargeService();
_refundService = new Stripe.RefundService();
_cardService = new Stripe.CardService();
_bankAccountService = new Stripe.BankAccountService();
_priceService = new Stripe.PriceService();
_testClockService = new Stripe.TestHelpers.TestClockService();
}
options.Limit = 100;
var items = new List<Stripe.Subscription>();
await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions()))
public Task<Stripe.Customer> CustomerCreateAsync(Stripe.CustomerCreateOptions options)
{
items.Add(i);
return _customerService.CreateAsync(options);
}
return items;
}
public async Task<Stripe.StripeList<Stripe.Price>> PriceListAsync(Stripe.PriceListOptions options = null)
{
return await _priceService.ListAsync(options);
}
public async Task<List<Stripe.TestHelpers.TestClock>> TestClockListAsync()
{
var items = new List<Stripe.TestHelpers.TestClock>();
var options = new Stripe.TestHelpers.TestClockListOptions()
public Task<Stripe.Customer> CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null)
{
Limit = 100
};
await foreach (var i in _testClockService.ListAutoPagingAsync(options))
{
items.Add(i);
return _customerService.GetAsync(id, options);
}
public Task<Stripe.Customer> CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null)
{
return _customerService.UpdateAsync(id, options);
}
public Task<Stripe.Customer> CustomerDeleteAsync(string id)
{
return _customerService.DeleteAsync(id);
}
public Task<Stripe.Subscription> SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options)
{
return _subscriptionService.CreateAsync(options);
}
public Task<Stripe.Subscription> SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null)
{
return _subscriptionService.GetAsync(id, options);
}
public Task<Stripe.Subscription> SubscriptionUpdateAsync(string id,
Stripe.SubscriptionUpdateOptions options = null)
{
return _subscriptionService.UpdateAsync(id, options);
}
public Task<Stripe.Subscription> SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null)
{
return _subscriptionService.CancelAsync(Id, options);
}
public Task<Stripe.Invoice> InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options)
{
return _invoiceService.UpcomingAsync(options);
}
public Task<Stripe.Invoice> InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options)
{
return _invoiceService.GetAsync(id, options);
}
public Task<Stripe.StripeList<Stripe.Invoice>> InvoiceListAsync(Stripe.InvoiceListOptions options)
{
return _invoiceService.ListAsync(options);
}
public Task<Stripe.Invoice> InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options)
{
return _invoiceService.UpdateAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options)
{
return _invoiceService.FinalizeInvoiceAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options)
{
return _invoiceService.SendInvoiceAsync(id, options);
}
public Task<Stripe.Invoice> InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null)
{
return _invoiceService.PayAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null)
{
return _invoiceService.DeleteAsync(id, options);
}
public Task<Stripe.Invoice> InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null)
{
return _invoiceService.VoidInvoiceAsync(id, options);
}
public IEnumerable<Stripe.PaymentMethod> PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options)
{
return _paymentMethodService.ListAutoPaging(options);
}
public Task<Stripe.PaymentMethod> PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null)
{
return _paymentMethodService.AttachAsync(id, options);
}
public Task<Stripe.PaymentMethod> PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null)
{
return _paymentMethodService.DetachAsync(id, options);
}
public Task<Stripe.TaxRate> TaxRateCreateAsync(Stripe.TaxRateCreateOptions options)
{
return _taxRateService.CreateAsync(options);
}
public Task<Stripe.TaxRate> TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options)
{
return _taxRateService.UpdateAsync(id, options);
}
public Task<Stripe.TaxId> TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options)
{
return _taxIdService.CreateAsync(id, options);
}
public Task<Stripe.TaxId> TaxIdDeleteAsync(string customerId, string taxIdId,
Stripe.TaxIdDeleteOptions options = null)
{
return _taxIdService.DeleteAsync(customerId, taxIdId);
}
public Task<Stripe.StripeList<Stripe.Charge>> ChargeListAsync(Stripe.ChargeListOptions options)
{
return _chargeService.ListAsync(options);
}
public Task<Stripe.Refund> RefundCreateAsync(Stripe.RefundCreateOptions options)
{
return _refundService.CreateAsync(options);
}
public Task<Stripe.Card> CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null)
{
return _cardService.DeleteAsync(customerId, cardId, options);
}
public Task<Stripe.BankAccount> BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null)
{
return _bankAccountService.CreateAsync(customerId, options);
}
public Task<Stripe.BankAccount> BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null)
{
return _bankAccountService.DeleteAsync(customerId, bankAccount, options);
}
public async Task<List<Stripe.Subscription>> SubscriptionListAsync(StripeSubscriptionListOptions options)
{
if (!options.SelectAll)
{
return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data;
}
options.Limit = 100;
var items = new List<Stripe.Subscription>();
await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions()))
{
items.Add(i);
}
return items;
}
public async Task<Stripe.StripeList<Stripe.Price>> PriceListAsync(Stripe.PriceListOptions options = null)
{
return await _priceService.ListAsync(options);
}
public async Task<List<Stripe.TestHelpers.TestClock>> TestClockListAsync()
{
var items = new List<Stripe.TestHelpers.TestClock>();
var options = new Stripe.TestHelpers.TestClockListOptions()
{
Limit = 100
};
await foreach (var i in _testClockService.ListAutoPagingAsync(options))
{
items.Add(i);
}
return items;
}
return items;
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,31 +1,32 @@
using Bit.Core.Exceptions;
namespace Bit.Core.Services;
public class StripeSyncService : IStripeSyncService
namespace Bit.Core.Services
{
private readonly IStripeAdapter _stripeAdapter;
public StripeSyncService(IStripeAdapter stripeAdapter)
public class StripeSyncService : IStripeSyncService
{
_stripeAdapter = stripeAdapter;
}
private readonly IStripeAdapter _stripeAdapter;
public async Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress)
{
if (string.IsNullOrWhiteSpace(gatewayCustomerId))
public StripeSyncService(IStripeAdapter stripeAdapter)
{
throw new InvalidGatewayCustomerIdException();
_stripeAdapter = stripeAdapter;
}
if (string.IsNullOrWhiteSpace(emailAddress))
public async Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress)
{
throw new InvalidEmailException();
if (string.IsNullOrWhiteSpace(gatewayCustomerId))
{
throw new InvalidGatewayCustomerIdException();
}
if (string.IsNullOrWhiteSpace(emailAddress))
{
throw new InvalidEmailException();
}
var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId);
await _stripeAdapter.CustomerUpdateAsync(customer.Id,
new Stripe.CustomerUpdateOptions { Email = emailAddress });
}
var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId);
await _stripeAdapter.CustomerUpdateAsync(customer.Id,
new Stripe.CustomerUpdateOptions { Email = emailAddress });
}
}

File diff suppressed because it is too large Load Diff