diff --git a/src/Api/Platform/Push/Controllers/PushController.cs b/src/Api/Platform/Push/Controllers/PushController.cs index 4b9f1c3e11..8b9e8b52a0 100644 --- a/src/Api/Platform/Push/Controllers/PushController.cs +++ b/src/Api/Platform/Push/Controllers/PushController.cs @@ -43,7 +43,7 @@ public class PushController : Controller { CheckUsage(); await _pushRegistrationService.CreateOrUpdateRegistrationAsync(model.PushToken, Prefix(model.DeviceId), - Prefix(model.UserId), Prefix(model.Identifier), model.Type); + Prefix(model.UserId), Prefix(model.Identifier), model.Type, model.OrganizationIds.Select(Prefix)); } [HttpPost("delete")] @@ -79,12 +79,12 @@ public class PushController : Controller if (!string.IsNullOrWhiteSpace(model.UserId)) { await _pushNotificationService.SendPayloadToUserAsync(Prefix(model.UserId), - model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); + model.Type, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId), model.ClientType); } else if (!string.IsNullOrWhiteSpace(model.OrganizationId)) { await _pushNotificationService.SendPayloadToOrganizationAsync(Prefix(model.OrganizationId), - model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); + model.Type, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId), model.ClientType); } } diff --git a/src/Core/Context/CurrentContext.cs b/src/Core/Context/CurrentContext.cs index 2767b5925f..b4a250fe2b 100644 --- a/src/Core/Context/CurrentContext.cs +++ b/src/Core/Context/CurrentContext.cs @@ -169,6 +169,11 @@ public class CurrentContext : ICurrentContext DeviceIdentifier = GetClaimValue(claimsDict, Claims.Device); + if (Enum.TryParse(GetClaimValue(claimsDict, Claims.DeviceType), out DeviceType deviceType)) + { + DeviceType = deviceType; + } + Organizations = GetOrganizations(claimsDict, orgApi); Providers = GetProviders(claimsDict); diff --git a/src/Core/Enums/PushType.cs b/src/Core/Enums/PushType.cs index ee1b59990f..b656e70601 100644 --- a/src/Core/Enums/PushType.cs +++ b/src/Core/Enums/PushType.cs @@ -27,4 +27,6 @@ public enum PushType : byte SyncOrganizations = 17, SyncOrganizationStatusChanged = 18, SyncOrganizationCollectionSettingChanged = 19, + + SyncNotification = 20, } diff --git a/src/Core/Identity/Claims.cs b/src/Core/Identity/Claims.cs index b1223a6e63..65d5eb210a 100644 --- a/src/Core/Identity/Claims.cs +++ b/src/Core/Identity/Claims.cs @@ -6,6 +6,7 @@ public static class Claims public const string SecurityStamp = "sstamp"; public const string Premium = "premium"; public const string Device = "device"; + public const string DeviceType = "devicetype"; public const string OrganizationOwner = "orgowner"; public const string OrganizationAdmin = "orgadmin"; diff --git a/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs b/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs index 580c1c3b60..ee787dd083 100644 --- a/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs +++ b/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs @@ -15,4 +15,5 @@ public class PushRegistrationRequestModel public DeviceType Type { get; set; } [Required] public string Identifier { get; set; } + public IEnumerable OrganizationIds { get; set; } } diff --git a/src/Core/Models/Api/Request/PushSendRequestModel.cs b/src/Core/Models/Api/Request/PushSendRequestModel.cs index b85c8fb555..7247e6d25f 100644 --- a/src/Core/Models/Api/Request/PushSendRequestModel.cs +++ b/src/Core/Models/Api/Request/PushSendRequestModel.cs @@ -1,18 +1,18 @@ -using System.ComponentModel.DataAnnotations; +#nullable enable +using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; namespace Bit.Core.Models.Api; public class PushSendRequestModel : IValidatableObject { - public string UserId { get; set; } - public string OrganizationId { get; set; } - public string DeviceId { get; set; } - public string Identifier { get; set; } - [Required] - public PushType? Type { get; set; } - [Required] - public object Payload { get; set; } + public string? UserId { get; set; } + public string? OrganizationId { get; set; } + public string? DeviceId { get; set; } + public string? Identifier { get; set; } + public required PushType Type { get; set; } + public required object Payload { get; set; } + public ClientType? ClientType { get; set; } public IEnumerable Validate(ValidationContext validationContext) { diff --git a/src/Core/Models/PushNotification.cs b/src/Core/Models/PushNotification.cs index e2247881ea..fd27ced6c5 100644 --- a/src/Core/Models/PushNotification.cs +++ b/src/Core/Models/PushNotification.cs @@ -45,6 +45,15 @@ public class SyncSendPushNotification public DateTime RevisionDate { get; set; } } +public class SyncNotificationPushNotification +{ + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public ClientType ClientType { get; set; } + public DateTime RevisionDate { get; set; } +} + public class AuthRequestPushNotification { public Guid UserId { get; set; } diff --git a/src/Core/NotificationCenter/Commands/CreateNotificationCommand.cs b/src/Core/NotificationCenter/Commands/CreateNotificationCommand.cs index 4f76950a34..f378a3688a 100644 --- a/src/Core/NotificationCenter/Commands/CreateNotificationCommand.cs +++ b/src/Core/NotificationCenter/Commands/CreateNotificationCommand.cs @@ -4,6 +4,7 @@ using Bit.Core.NotificationCenter.Authorization; using Bit.Core.NotificationCenter.Commands.Interfaces; using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Repositories; +using Bit.Core.Platform.Push; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; @@ -14,14 +15,17 @@ public class CreateNotificationCommand : ICreateNotificationCommand private readonly ICurrentContext _currentContext; private readonly IAuthorizationService _authorizationService; private readonly INotificationRepository _notificationRepository; + private readonly IPushNotificationService _pushNotificationService; public CreateNotificationCommand(ICurrentContext currentContext, IAuthorizationService authorizationService, - INotificationRepository notificationRepository) + INotificationRepository notificationRepository, + IPushNotificationService pushNotificationService) { _currentContext = currentContext; _authorizationService = authorizationService; _notificationRepository = notificationRepository; + _pushNotificationService = pushNotificationService; } public async Task CreateAsync(Notification notification) @@ -31,6 +35,10 @@ public class CreateNotificationCommand : ICreateNotificationCommand await _authorizationService.AuthorizeOrThrowAsync(_currentContext.HttpContext.User, notification, NotificationOperations.Create); - return await _notificationRepository.CreateAsync(notification); + var newNotification = await _notificationRepository.CreateAsync(notification); + + await _pushNotificationService.PushSyncNotificationAsync(newNotification); + + return newNotification; } } diff --git a/src/Core/NotificationHub/INotificationHubPool.cs b/src/Core/NotificationHub/INotificationHubPool.cs index 7c383d7b96..18bae98bc6 100644 --- a/src/Core/NotificationHub/INotificationHubPool.cs +++ b/src/Core/NotificationHub/INotificationHubPool.cs @@ -4,6 +4,6 @@ namespace Bit.Core.NotificationHub; public interface INotificationHubPool { - NotificationHubClient ClientFor(Guid comb); + INotificationHubClient ClientFor(Guid comb); INotificationHubProxy AllClients { get; } } diff --git a/src/Core/NotificationHub/NotificationHubPool.cs b/src/Core/NotificationHub/NotificationHubPool.cs index 7448aad5bd..8993ee2b8e 100644 --- a/src/Core/NotificationHub/NotificationHubPool.cs +++ b/src/Core/NotificationHub/NotificationHubPool.cs @@ -43,7 +43,7 @@ public class NotificationHubPool : INotificationHubPool /// /// /// Thrown when no notification hub is found for a given comb. - public NotificationHubClient ClientFor(Guid comb) + public INotificationHubClient ClientFor(Guid comb) { var possibleConnections = _connections.Where(c => c.RegistrationEnabled(comb)).ToArray(); if (possibleConnections.Length == 0) diff --git a/src/Core/NotificationHub/NotificationHubPushNotificationService.cs b/src/Core/NotificationHub/NotificationHubPushNotificationService.cs index d99cbf3fe7..ed44e69218 100644 --- a/src/Core/NotificationHub/NotificationHubPushNotificationService.cs +++ b/src/Core/NotificationHub/NotificationHubPushNotificationService.cs @@ -12,6 +12,7 @@ using Bit.Core.Tools.Entities; using Bit.Core.Vault.Entities; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; +using Notification = Bit.Core.NotificationCenter.Entities.Notification; namespace Bit.Core.NotificationHub; @@ -135,11 +136,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false) { - var message = new UserPushNotification - { - UserId = userId, - Date = DateTime.UtcNow - }; + var message = new UserPushNotification { UserId = userId, Date = DateTime.UtcNow }; await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext); } @@ -184,31 +181,54 @@ public class NotificationHubPushNotificationService : IPushNotificationService await PushAuthRequestAsync(authRequest, PushType.AuthRequestResponse); } + public async Task PushSyncNotificationAsync(Notification notification) + { + var message = new SyncNotificationPushNotification + { + Id = notification.Id, + UserId = notification.UserId, + OrganizationId = notification.OrganizationId, + ClientType = notification.ClientType, + RevisionDate = notification.RevisionDate + }; + + if (notification.UserId.HasValue) + { + await SendPayloadToUserAsync(notification.UserId.Value, PushType.SyncNotification, message, true, + notification.ClientType); + } + else if (notification.OrganizationId.HasValue) + { + await SendPayloadToOrganizationAsync(notification.OrganizationId.Value, PushType.SyncNotification, message, + true, notification.ClientType); + } + } + private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type) { - var message = new AuthRequestPushNotification - { - Id = authRequest.Id, - UserId = authRequest.UserId - }; + var message = new AuthRequestPushNotification { Id = authRequest.Id, UserId = authRequest.UserId }; await SendPayloadToUserAsync(authRequest.UserId, type, message, true); } - private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) + private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext, + ClientType? clientType = null) { - await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); + await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext), + clientType: clientType); } - private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) + private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, + bool excludeCurrentContext, ClientType? clientType = null) { - await SendPayloadToUserAsync(orgId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); + await SendPayloadToOrganizationAsync(orgId.ToString(), type, payload, + GetContextIdentifier(excludeCurrentContext), clientType: clientType); } public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) + string deviceId = null, ClientType? clientType = null) { - var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier); + var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier, clientType); await SendPayloadAsync(tag, type, payload); if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) { @@ -217,9 +237,9 @@ public class NotificationHubPushNotificationService : IPushNotificationService } public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) + string deviceId = null, ClientType? clientType = null) { - var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier); + var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier, clientType); await SendPayloadAsync(tag, type, payload); if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) { @@ -259,18 +279,23 @@ public class NotificationHubPushNotificationService : IPushNotificationService return null; } - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + var currentContext = + _httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; return currentContext?.DeviceIdentifier; } - private string BuildTag(string tag, string identifier) + private string BuildTag(string tag, string identifier, ClientType? clientType) { if (!string.IsNullOrWhiteSpace(identifier)) { tag += $" && !deviceIdentifier:{SanitizeTagInput(identifier)}"; } + if (clientType.HasValue && clientType.Value != ClientType.All) + { + tag += $" && clientType:{clientType}"; + } + return $"({tag})"; } @@ -279,8 +304,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService var results = await _notificationHubPool.AllClients.SendTemplateNotificationAsync( new Dictionary { - { "type", ((byte)type).ToString() }, - { "payload", JsonSerializer.Serialize(payload) } + { "type", ((byte)type).ToString() }, { "payload", JsonSerializer.Serialize(payload) } }, tag); if (_enableTracing) @@ -291,7 +315,9 @@ public class NotificationHubPushNotificationService : IPushNotificationService { continue; } - _logger.LogInformation("Azure Notification Hub Tracking ID: {Id} | {Type} push notification with {Success} successes and {Failure} failures with a payload of {@Payload} and result of {@Results}", + + _logger.LogInformation( + "Azure Notification Hub Tracking ID: {Id} | {Type} push notification with {Success} successes and {Failure} failures with a payload of {@Payload} and result of {@Results}", outcome.TrackingId, type, outcome.Success, outcome.Failure, payload, outcome.Results); } } diff --git a/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs b/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs index 180b2b641b..0c9bbea425 100644 --- a/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs +++ b/src/Core/NotificationHub/NotificationHubPushRegistrationService.cs @@ -2,36 +2,26 @@ using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; -using Bit.Core.Settings; +using Bit.Core.Utilities; using Microsoft.Azure.NotificationHubs; -using Microsoft.Extensions.Logging; namespace Bit.Core.NotificationHub; public class NotificationHubPushRegistrationService : IPushRegistrationService { private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; private readonly INotificationHubPool _notificationHubPool; - private readonly IServiceProvider _serviceProvider; - private readonly ILogger _logger; public NotificationHubPushRegistrationService( IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings, - INotificationHubPool notificationHubPool, - IServiceProvider serviceProvider, - ILogger logger) + INotificationHubPool notificationHubPool) { _installationDeviceRepository = installationDeviceRepository; - _globalSettings = globalSettings; _notificationHubPool = notificationHubPool; - _serviceProvider = serviceProvider; - _logger = logger; } public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type) + string identifier, DeviceType type, IEnumerable organizationIds) { if (string.IsNullOrWhiteSpace(pushToken)) { @@ -45,16 +35,21 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService Templates = new Dictionary() }; - installation.Tags = new List - { - $"userId:{userId}" - }; + var clientType = DeviceTypes.ToClientType(type); + + installation.Tags = new List { $"userId:{userId}", $"clientType:{clientType}" }; if (!string.IsNullOrWhiteSpace(identifier)) { installation.Tags.Add("deviceIdentifier:" + identifier); } + var organizationIdsList = organizationIds.ToList(); + foreach (var organizationId in organizationIdsList) + { + installation.Tags.Add($"organizationId:{organizationId}"); + } + string payloadTemplate = null, messageTemplate = null, badgeMessageTemplate = null; switch (type) { @@ -84,10 +79,12 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService break; } - BuildInstallationTemplate(installation, "payload", payloadTemplate, userId, identifier); - BuildInstallationTemplate(installation, "message", messageTemplate, userId, identifier); + BuildInstallationTemplate(installation, "payload", payloadTemplate, userId, identifier, clientType, + organizationIdsList); + BuildInstallationTemplate(installation, "message", messageTemplate, userId, identifier, clientType, + organizationIdsList); BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate, - userId, identifier); + userId, identifier, clientType, organizationIdsList); await ClientFor(GetComb(deviceId)).CreateOrUpdateInstallationAsync(installation); if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) @@ -97,7 +94,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService } private void BuildInstallationTemplate(Installation installation, string templateId, string templateBody, - string userId, string identifier) + string userId, string identifier, ClientType clientType, List organizationIds) { if (templateBody == null) { @@ -111,8 +108,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService Body = templateBody, Tags = new List { - fullTemplateId, - $"{fullTemplateId}_userId:{userId}" + fullTemplateId, $"{fullTemplateId}_userId:{userId}", $"clientType:{clientType}" } }; @@ -121,6 +117,11 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService template.Tags.Add($"{fullTemplateId}_deviceIdentifier:{identifier}"); } + foreach (var organizationId in organizationIds) + { + template.Tags.Add($"organizationId:{organizationId}"); + } + installation.Templates.Add(fullTemplateId, template); } @@ -197,7 +198,7 @@ public class NotificationHubPushRegistrationService : IPushRegistrationService } } - private NotificationHubClient ClientFor(Guid deviceId) + private INotificationHubClient ClientFor(Guid deviceId) { return _notificationHubPool.ClientFor(deviceId); } diff --git a/src/Core/Platform/Push/Services/AzureQueuePushNotificationService.cs b/src/Core/Platform/Push/Services/AzureQueuePushNotificationService.cs index 33272ce870..d3509c5437 100644 --- a/src/Core/Platform/Push/Services/AzureQueuePushNotificationService.cs +++ b/src/Core/Platform/Push/Services/AzureQueuePushNotificationService.cs @@ -5,26 +5,25 @@ using Bit.Core.Auth.Entities; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Models; -using Bit.Core.Settings; +using Bit.Core.NotificationCenter.Entities; using Bit.Core.Tools.Entities; using Bit.Core.Utilities; using Bit.Core.Vault.Entities; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; namespace Bit.Core.Platform.Push.Internal; public class AzureQueuePushNotificationService : IPushNotificationService { private readonly QueueClient _queueClient; - private readonly GlobalSettings _globalSettings; private readonly IHttpContextAccessor _httpContextAccessor; public AzureQueuePushNotificationService( - GlobalSettings globalSettings, + [FromKeyedServices("notifications")] QueueClient queueClient, IHttpContextAccessor httpContextAccessor) { - _queueClient = new QueueClient(globalSettings.Notifications.ConnectionString, "notifications"); - _globalSettings = globalSettings; + _queueClient = queueClient; _httpContextAccessor = httpContextAccessor; } @@ -129,11 +128,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false) { - var message = new UserPushNotification - { - UserId = userId, - Date = DateTime.UtcNow - }; + var message = new UserPushNotification { UserId = userId, Date = DateTime.UtcNow }; await SendMessageAsync(type, message, excludeCurrentContext); } @@ -150,11 +145,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type) { - var message = new AuthRequestPushNotification - { - Id = authRequest.Id, - UserId = authRequest.UserId - }; + var message = new AuthRequestPushNotification { Id = authRequest.Id, UserId = authRequest.UserId }; await SendMessageAsync(type, message, true); } @@ -174,6 +165,20 @@ public class AzureQueuePushNotificationService : IPushNotificationService await PushSendAsync(send, PushType.SyncSendDelete); } + public async Task PushSyncNotificationAsync(Notification notification) + { + var message = new SyncNotificationPushNotification + { + Id = notification.Id, + UserId = notification.UserId, + OrganizationId = notification.OrganizationId, + ClientType = notification.ClientType, + RevisionDate = notification.RevisionDate + }; + + await SendMessageAsync(PushType.SyncNotification, message, true); + } + private async Task PushSendAsync(Send send, PushType type) { if (send.UserId.HasValue) @@ -204,20 +209,20 @@ public class AzureQueuePushNotificationService : IPushNotificationService return null; } - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + var currentContext = + _httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; return currentContext?.DeviceIdentifier; } public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) + string deviceId = null, ClientType? clientType = null) { // Noop return Task.FromResult(0); } public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) + string deviceId = null, ClientType? clientType = null) { // Noop return Task.FromResult(0); diff --git a/src/Core/Platform/Push/Services/IPushNotificationService.cs b/src/Core/Platform/Push/Services/IPushNotificationService.cs index b015c17df2..5e1ab7067e 100644 --- a/src/Core/Platform/Push/Services/IPushNotificationService.cs +++ b/src/Core/Platform/Push/Services/IPushNotificationService.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Entities; using Bit.Core.Enums; +using Bit.Core.NotificationCenter.Entities; using Bit.Core.Tools.Entities; using Bit.Core.Vault.Entities; @@ -23,11 +24,13 @@ public interface IPushNotificationService Task PushSyncSendCreateAsync(Send send); Task PushSyncSendUpdateAsync(Send send); Task PushSyncSendDeleteAsync(Send send); + Task PushSyncNotificationAsync(Notification notification); Task PushAuthRequestAsync(AuthRequest authRequest); Task PushAuthRequestResponseAsync(AuthRequest authRequest); - Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, string deviceId = null); - Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null); Task PushSyncOrganizationStatusAsync(Organization organization); Task PushSyncOrganizationCollectionManagementSettingsAsync(Organization organization); + Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null, ClientType? clientType = null); + Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null, ClientType? clientType = null); } diff --git a/src/Core/Platform/Push/Services/IPushRegistrationService.cs b/src/Core/Platform/Push/Services/IPushRegistrationService.cs index 482e7ae1c4..0c4271f061 100644 --- a/src/Core/Platform/Push/Services/IPushRegistrationService.cs +++ b/src/Core/Platform/Push/Services/IPushRegistrationService.cs @@ -5,7 +5,7 @@ namespace Bit.Core.Platform.Push; public interface IPushRegistrationService { Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type); + string identifier, DeviceType type, IEnumerable organizationIds); Task DeleteRegistrationAsync(string deviceId); Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); diff --git a/src/Core/Platform/Push/Services/MultiServicePushNotificationService.cs b/src/Core/Platform/Push/Services/MultiServicePushNotificationService.cs index f1a5700013..4ad81e223b 100644 --- a/src/Core/Platform/Push/Services/MultiServicePushNotificationService.cs +++ b/src/Core/Platform/Push/Services/MultiServicePushNotificationService.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Entities; using Bit.Core.Enums; +using Bit.Core.NotificationCenter.Entities; using Bit.Core.Settings; using Bit.Core.Tools.Entities; using Bit.Core.Vault.Entities; @@ -131,20 +132,6 @@ public class MultiServicePushNotificationService : IPushNotificationService return Task.FromResult(0); } - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId)); - return Task.FromResult(0); - } - - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId)); - return Task.FromResult(0); - } - public Task PushSyncOrganizationStatusAsync(Organization organization) { PushToServices((s) => s.PushSyncOrganizationStatusAsync(organization)); @@ -157,6 +144,26 @@ public class MultiServicePushNotificationService : IPushNotificationService return Task.CompletedTask; } + public Task PushSyncNotificationAsync(Notification notification) + { + PushToServices((s) => s.PushSyncNotificationAsync(notification)); + return Task.CompletedTask; + } + + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null, ClientType? clientType = null) + { + PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId, clientType)); + return Task.FromResult(0); + } + + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null, ClientType? clientType = null) + { + PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId, clientType)); + return Task.FromResult(0); + } + private void PushToServices(Func pushFunc) { if (_services != null) diff --git a/src/Core/Platform/Push/Services/NoopPushNotificationService.cs b/src/Core/Platform/Push/Services/NoopPushNotificationService.cs index 4a185bee1a..463a2fde88 100644 --- a/src/Core/Platform/Push/Services/NoopPushNotificationService.cs +++ b/src/Core/Platform/Push/Services/NoopPushNotificationService.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Entities; using Bit.Core.Enums; +using Bit.Core.NotificationCenter.Entities; using Bit.Core.Tools.Entities; using Bit.Core.Vault.Entities; @@ -84,7 +85,7 @@ public class NoopPushNotificationService : IPushNotificationService } public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) + string deviceId = null, ClientType? clientType = null) { return Task.FromResult(0); } @@ -107,8 +108,10 @@ public class NoopPushNotificationService : IPushNotificationService } public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) + string deviceId = null, ClientType? clientType = null) { return Task.FromResult(0); } + + public Task PushSyncNotificationAsync(Notification notification) => Task.CompletedTask; } diff --git a/src/Core/Platform/Push/Services/NoopPushRegistrationService.cs b/src/Core/Platform/Push/Services/NoopPushRegistrationService.cs index 6d1716a6ce..6bcf9e893a 100644 --- a/src/Core/Platform/Push/Services/NoopPushRegistrationService.cs +++ b/src/Core/Platform/Push/Services/NoopPushRegistrationService.cs @@ -10,7 +10,7 @@ public class NoopPushRegistrationService : IPushRegistrationService } public Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type) + string identifier, DeviceType type, IEnumerable organizationIds) { return Task.FromResult(0); } diff --git a/src/Core/Platform/Push/Services/NotificationsApiPushNotificationService.cs b/src/Core/Platform/Push/Services/NotificationsApiPushNotificationService.cs index 5ebfc811ef..5c6b46f63e 100644 --- a/src/Core/Platform/Push/Services/NotificationsApiPushNotificationService.cs +++ b/src/Core/Platform/Push/Services/NotificationsApiPushNotificationService.cs @@ -3,6 +3,7 @@ using Bit.Core.Auth.Entities; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Models; +using Bit.Core.NotificationCenter.Entities; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Tools.Entities; @@ -183,6 +184,20 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService await PushSendAsync(send, PushType.SyncSendDelete); } + public async Task PushSyncNotificationAsync(Notification notification) + { + var message = new SyncNotificationPushNotification + { + Id = notification.Id, + UserId = notification.UserId, + OrganizationId = notification.OrganizationId, + ClientType = notification.ClientType, + RevisionDate = notification.RevisionDate + }; + + await SendMessageAsync(PushType.SyncNotification, message, true); + } + private async Task PushSendAsync(Send send, PushType type) { if (send.UserId.HasValue) @@ -212,20 +227,20 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService return null; } - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + var currentContext = + _httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; return currentContext?.DeviceIdentifier; } public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) + string deviceId = null, ClientType? clientType = null) { // Noop return Task.FromResult(0); } public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) + string deviceId = null, ClientType? clientType = null) { // Noop return Task.FromResult(0); diff --git a/src/Core/Platform/Push/Services/RelayPushNotificationService.cs b/src/Core/Platform/Push/Services/RelayPushNotificationService.cs index 6549ab47c3..f51ab004a6 100644 --- a/src/Core/Platform/Push/Services/RelayPushNotificationService.cs +++ b/src/Core/Platform/Push/Services/RelayPushNotificationService.cs @@ -5,6 +5,7 @@ using Bit.Core.Enums; using Bit.Core.IdentityServer; using Bit.Core.Models; using Bit.Core.Models.Api; +using Bit.Core.NotificationCenter.Entities; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -138,11 +139,7 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti private async Task PushUserAsync(Guid userId, PushType type, bool excludeCurrentContext = false) { - var message = new UserPushNotification - { - UserId = userId, - Date = DateTime.UtcNow - }; + var message = new UserPushNotification { UserId = userId, Date = DateTime.UtcNow }; await SendPayloadToUserAsync(userId, type, message, excludeCurrentContext); } @@ -189,69 +186,32 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti private async Task PushAuthRequestAsync(AuthRequest authRequest, PushType type) { - var message = new AuthRequestPushNotification - { - Id = authRequest.Id, - UserId = authRequest.UserId - }; + var message = new AuthRequestPushNotification { Id = authRequest.Id, UserId = authRequest.UserId }; await SendPayloadToUserAsync(authRequest.UserId, type, message, true); } - private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) + public async Task PushSyncNotificationAsync(Notification notification) { - var request = new PushSendRequestModel + var message = new SyncNotificationPushNotification { - UserId = userId.ToString(), - Type = type, - Payload = payload + Id = notification.Id, + UserId = notification.UserId, + OrganizationId = notification.OrganizationId, + ClientType = notification.ClientType, + RevisionDate = notification.RevisionDate }; - await AddCurrentContextAsync(request, excludeCurrentContext); - await SendAsync(HttpMethod.Post, "push/send", request); - } - - private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) - { - var request = new PushSendRequestModel + if (notification.UserId.HasValue) { - OrganizationId = orgId.ToString(), - Type = type, - Payload = payload - }; - - await AddCurrentContextAsync(request, excludeCurrentContext); - await SendAsync(HttpMethod.Post, "push/send", request); - } - - private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier) - { - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; - if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier)) - { - var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier); - if (device != null) - { - request.DeviceId = device.Id.ToString(); - } - if (addIdentifier) - { - request.Identifier = currentContext.DeviceIdentifier; - } + await SendPayloadToUserAsync(notification.UserId.Value, PushType.SyncNotification, message, true, + notification.ClientType); + } + else if (notification.OrganizationId.HasValue) + { + await SendPayloadToOrganizationAsync(notification.OrganizationId.Value, PushType.SyncNotification, message, + true, notification.ClientType); } - } - - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - throw new NotImplementedException(); - } - - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - throw new NotImplementedException(); } public async Task PushSyncOrganizationStatusAsync(Organization organization) @@ -278,4 +238,65 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti }, false ); + + private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext, + ClientType? clientType = null) + { + var request = new PushSendRequestModel + { + UserId = userId.ToString(), + Type = type, + Payload = payload, + ClientType = clientType + }; + + await AddCurrentContextAsync(request, excludeCurrentContext); + await SendAsync(HttpMethod.Post, "push/send", request); + } + + private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, + bool excludeCurrentContext, ClientType? clientType = null) + { + var request = new PushSendRequestModel + { + OrganizationId = orgId.ToString(), + Type = type, + Payload = payload, + ClientType = clientType + }; + + await AddCurrentContextAsync(request, excludeCurrentContext); + await SendAsync(HttpMethod.Post, "push/send", request); + } + + private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier) + { + var currentContext = + _httpContextAccessor?.HttpContext?.RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier)) + { + var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier); + if (device != null) + { + request.DeviceId = device.Id.ToString(); + } + + if (addIdentifier) + { + request.Identifier = currentContext.DeviceIdentifier; + } + } + } + + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null, ClientType? clientType = null) + { + throw new NotImplementedException(); + } + + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null, ClientType? clientType = null) + { + throw new NotImplementedException(); + } } diff --git a/src/Core/Platform/Push/Services/RelayPushRegistrationService.cs b/src/Core/Platform/Push/Services/RelayPushRegistrationService.cs index 79b033e877..b838fbde59 100644 --- a/src/Core/Platform/Push/Services/RelayPushRegistrationService.cs +++ b/src/Core/Platform/Push/Services/RelayPushRegistrationService.cs @@ -25,7 +25,7 @@ public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegi } public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type) + string identifier, DeviceType type, IEnumerable organizationIds) { var requestModel = new PushRegistrationRequestModel { @@ -33,7 +33,8 @@ public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegi Identifier = identifier, PushToken = pushToken, Type = type, - UserId = userId + UserId = userId, + OrganizationIds = organizationIds }; await SendAsync(HttpMethod.Post, "push/register", requestModel); } diff --git a/src/Core/Services/Implementations/DeviceService.cs b/src/Core/Services/Implementations/DeviceService.cs index afbc574417..28823eeda7 100644 --- a/src/Core/Services/Implementations/DeviceService.cs +++ b/src/Core/Services/Implementations/DeviceService.cs @@ -1,6 +1,7 @@ using Bit.Core.Auth.Models.Api.Request; using Bit.Core.Auth.Utilities; using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Platform.Push; using Bit.Core.Repositories; @@ -11,13 +12,16 @@ public class DeviceService : IDeviceService { private readonly IDeviceRepository _deviceRepository; private readonly IPushRegistrationService _pushRegistrationService; + private readonly IOrganizationUserRepository _organizationUserRepository; public DeviceService( IDeviceRepository deviceRepository, - IPushRegistrationService pushRegistrationService) + IPushRegistrationService pushRegistrationService, + IOrganizationUserRepository organizationUserRepository) { _deviceRepository = deviceRepository; _pushRegistrationService = pushRegistrationService; + _organizationUserRepository = organizationUserRepository; } public async Task SaveAsync(Device device) @@ -32,8 +36,13 @@ public class DeviceService : IDeviceService await _deviceRepository.ReplaceAsync(device); } + var organizationIdsString = + (await _organizationUserRepository.GetManyDetailsByUserAsync(device.UserId, + OrganizationUserStatusType.Confirmed)) + .Select(ou => ou.OrganizationId.ToString()); + await _pushRegistrationService.CreateOrUpdateRegistrationAsync(device.PushToken, device.Id.ToString(), - device.UserId.ToString(), device.Identifier, device.Type); + device.UserId.ToString(), device.Identifier, device.Type, organizationIdsString); } public async Task ClearTokenAsync(Device device) diff --git a/src/Identity/IdentityServer/ApiResources.cs b/src/Identity/IdentityServer/ApiResources.cs index a0712aafe7..f969d67908 100644 --- a/src/Identity/IdentityServer/ApiResources.cs +++ b/src/Identity/IdentityServer/ApiResources.cs @@ -18,6 +18,7 @@ public class ApiResources Claims.SecurityStamp, Claims.Premium, Claims.Device, + Claims.DeviceType, Claims.OrganizationOwner, Claims.OrganizationAdmin, Claims.OrganizationUser, diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index ea207a7aaa..5e78212cf1 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -210,6 +210,7 @@ public abstract class BaseRequestValidator where T : class if (device != null) { claims.Add(new Claim(Claims.Device, device.Identifier)); + claims.Add(new Claim(Claims.DeviceType, device.Type.ToString())); } var customResponse = new Dictionary(); diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 6f49822dc9..6d17ca9955 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -10,6 +10,8 @@ public static class HubHelpers private static JsonSerializerOptions _deserializerOptions = new JsonSerializerOptions { PropertyNameCaseInsensitive = true }; + private static readonly string _receiveMessageMethod = "ReceiveMessage"; + public static async Task SendNotificationToHubAsync( string notificationJson, IHubContext hubContext, @@ -18,7 +20,8 @@ public static class HubHelpers CancellationToken cancellationToken = default(CancellationToken) ) { - var notification = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); + var notification = + JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); logger.LogInformation("Sending notification: {NotificationType}", notification.Type); switch (notification.Type) { @@ -32,14 +35,15 @@ public static class HubHelpers if (cipherNotification.Payload.UserId.HasValue) { await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } else if (cipherNotification.Payload.OrganizationId.HasValue) { - await hubContext.Clients.Group( - $"Organization_{cipherNotification.Payload.OrganizationId}") - .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); + await hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(cipherNotification.Payload.OrganizationId.Value)) + .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } + break; case PushType.SyncFolderUpdate: case PushType.SyncFolderCreate: @@ -48,7 +52,7 @@ public static class HubHelpers JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", folderNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, folderNotification, cancellationToken); break; case PushType.SyncCiphers: case PushType.SyncVault: @@ -60,30 +64,30 @@ public static class HubHelpers JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", userNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, userNotification, cancellationToken); break; case PushType.SyncSendCreate: case PushType.SyncSendUpdate: case PushType.SyncSendDelete: var sendNotification = JsonSerializer.Deserialize>( - notificationJson, _deserializerOptions); + notificationJson, _deserializerOptions); await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", sendNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, sendNotification, cancellationToken); break; case PushType.AuthRequestResponse: var authRequestResponseNotification = JsonSerializer.Deserialize>( - notificationJson, _deserializerOptions); + notificationJson, _deserializerOptions); await anonymousHubContext.Clients.Group(authRequestResponseNotification.Payload.Id.ToString()) .SendAsync("AuthRequestResponseRecieved", authRequestResponseNotification, cancellationToken); break; case PushType.AuthRequest: var authRequestNotification = JsonSerializer.Deserialize>( - notificationJson, _deserializerOptions); + notificationJson, _deserializerOptions); await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", authRequestNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, authRequestNotification, cancellationToken); break; case PushType.SyncOrganizationStatusChanged: var orgStatusNotification = @@ -99,6 +103,32 @@ public static class HubHelpers await hubContext.Clients.Group($"Organization_{organizationCollectionSettingsChangedNotification.Payload.OrganizationId}") .SendAsync("ReceiveMessage", organizationCollectionSettingsChangedNotification, cancellationToken); break; + case PushType.SyncNotification: + var syncNotification = + JsonSerializer.Deserialize>( + notificationJson, _deserializerOptions); + if (syncNotification.Payload.UserId.HasValue) + { + if (syncNotification.Payload.ClientType == ClientType.All) + { + await hubContext.Clients.User(syncNotification.Payload.UserId.ToString()) + .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); + } + else + { + await hubContext.Clients.Group(NotificationsHub.GetUserGroup( + syncNotification.Payload.UserId.Value, syncNotification.Payload.ClientType)) + .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); + } + } + else if (syncNotification.Payload.OrganizationId.HasValue) + { + await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( + syncNotification.Payload.OrganizationId.Value, syncNotification.Payload.ClientType)) + .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); + } + + break; default: break; } diff --git a/src/Notifications/NotificationsHub.cs b/src/Notifications/NotificationsHub.cs index a86cf329c5..27cd19c0a0 100644 --- a/src/Notifications/NotificationsHub.cs +++ b/src/Notifications/NotificationsHub.cs @@ -1,5 +1,7 @@ using Bit.Core.Context; +using Bit.Core.Enums; using Bit.Core.Settings; +using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; namespace Bit.Notifications; @@ -20,13 +22,25 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub { var currentContext = new CurrentContext(null, null); await currentContext.BuildAsync(Context.User, _globalSettings); + + var clientType = DeviceTypes.ToClientType(currentContext.DeviceType); + if (clientType != ClientType.All && currentContext.UserId.HasValue) + { + await Groups.AddToGroupAsync(Context.ConnectionId, GetUserGroup(currentContext.UserId.Value, clientType)); + } + if (currentContext.Organizations != null) { foreach (var org in currentContext.Organizations) { - await Groups.AddToGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); + await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id)); + if (clientType != ClientType.All) + { + await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType)); + } } } + _connectionCounter.Increment(); await base.OnConnectedAsync(); } @@ -35,14 +49,39 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub { var currentContext = new CurrentContext(null, null); await currentContext.BuildAsync(Context.User, _globalSettings); + + var clientType = DeviceTypes.ToClientType(currentContext.DeviceType); + if (clientType != ClientType.All && currentContext.UserId.HasValue) + { + await Groups.RemoveFromGroupAsync(Context.ConnectionId, + GetUserGroup(currentContext.UserId.Value, clientType)); + } + if (currentContext.Organizations != null) { foreach (var org in currentContext.Organizations) { - await Groups.RemoveFromGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); + await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id)); + if (clientType != ClientType.All) + { + await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType)); + } } } + _connectionCounter.Decrement(); await base.OnDisconnectedAsync(exception); } + + public static string GetUserGroup(Guid userId, ClientType clientType) + { + return $"UserClientType_{userId}_{clientType}"; + } + + public static string GetOrganizationGroup(Guid organizationId, ClientType? clientType = null) + { + return clientType is null or ClientType.All + ? $"Organization_{organizationId}" + : $"OrganizationClientType_{organizationId}_{clientType}"; + } } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 192871bffc..5a1205c961 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -3,6 +3,7 @@ using System.Reflection; using System.Security.Claims; using System.Security.Cryptography.X509Certificates; using AspNetCoreRateLimit; +using Azure.Storage.Queues; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Services; @@ -306,7 +307,10 @@ public static class ServiceCollectionExtensions services.AddKeyedSingleton("implementation"); if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) { - services.AddKeyedSingleton("implementation"); + services.AddKeyedSingleton("notifications", + (_, _) => new QueueClient(globalSettings.Notifications.ConnectionString, "notifications")); + services.AddKeyedSingleton( + "implementation"); } } diff --git a/test/Api.IntegrationTest/NotificationCenter/Controllers/NotificationsControllerTests.cs b/test/Api.IntegrationTest/NotificationCenter/Controllers/NotificationsControllerTests.cs index 6d487c5d8f..ca04c9775d 100644 --- a/test/Api.IntegrationTest/NotificationCenter/Controllers/NotificationsControllerTests.cs +++ b/test/Api.IntegrationTest/NotificationCenter/Controllers/NotificationsControllerTests.cs @@ -133,12 +133,10 @@ public class NotificationsControllerTests : IClassFixture [InlineData(null, null, "2", 10)] [InlineData(10, null, "2", 10)] [InlineData(10, 2, "3", 10)] - [InlineData(10, 3, null, 0)] - [InlineData(15, null, "2", 15)] - [InlineData(15, 2, null, 5)] - [InlineData(20, null, "2", 20)] - [InlineData(20, 2, null, 0)] - [InlineData(1000, null, null, 20)] + [InlineData(10, 3, null, 4)] + [InlineData(24, null, "2", 24)] + [InlineData(24, 2, null, 0)] + [InlineData(1000, null, null, 24)] public async Task ListAsync_PaginationFilter_ReturnsNextPageOfNotificationsCorrectOrder( int? pageSize, int? pageNumber, string? expectedContinuationToken, int expectedCount) { @@ -505,11 +503,12 @@ public class NotificationsControllerTests : IClassFixture userPartOrOrganizationNotificationWithStatuses } .SelectMany(n => n) + .Where(n => n.Item1.ClientType is ClientType.All or ClientType.Web) .ToList(); } private async Task> CreateNotificationsAsync(Guid? userId = null, Guid? organizationId = null, - int numberToCreate = 5) + int numberToCreate = 3) { var priorities = Enum.GetValues(); var clientTypes = Enum.GetValues(); @@ -570,13 +569,9 @@ public class NotificationsControllerTests : IClassFixture DeletedDate = DateTime.UtcNow - TimeSpan.FromMinutes(_random.Next(3600)) }); - return - [ - (notifications[0], readDateNotificationStatus), - (notifications[1], deletedDateNotificationStatus), - (notifications[2], readDateAndDeletedDateNotificationStatus), - (notifications[3], null), - (notifications[4], null) - ]; + List statuses = + [readDateNotificationStatus, deletedDateNotificationStatus, readDateAndDeletedDateNotificationStatus]; + + return notifications.Select(n => (n, statuses.Find(s => s.NotificationId == n.Id))).ToList(); } } diff --git a/test/Core.Test/AutoFixture/QueueClientFixtures.cs b/test/Core.Test/AutoFixture/QueueClientFixtures.cs new file mode 100644 index 0000000000..2a722f3853 --- /dev/null +++ b/test/Core.Test/AutoFixture/QueueClientFixtures.cs @@ -0,0 +1,35 @@ +#nullable enable +using AutoFixture; +using AutoFixture.Kernel; +using Azure.Storage.Queues; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; + +namespace Bit.Core.Test.AutoFixture; + +public class QueueClientBuilder : ISpecimenBuilder +{ + public object Create(object request, ISpecimenContext context) + { + var type = request as Type; + if (type == typeof(QueueClient)) + { + return Substitute.For(); + } + + return new NoSpecimen(); + } +} + +public class QueueClientCustomizeAttribute : BitCustomizeAttribute +{ + public override ICustomization GetCustomization() => new QueueClientFixtures(); +} + +public class QueueClientFixtures : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new QueueClientBuilder()); + } +} diff --git a/test/Core.Test/Models/Api/Request/PushSendRequestModelTests.cs b/test/Core.Test/Models/Api/Request/PushSendRequestModelTests.cs new file mode 100644 index 0000000000..41a6c25bf2 --- /dev/null +++ b/test/Core.Test/Models/Api/Request/PushSendRequestModelTests.cs @@ -0,0 +1,94 @@ +#nullable enable +using System.ComponentModel.DataAnnotations; +using System.Text.Json; +using Bit.Core.Enums; +using Bit.Core.Models.Api; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Core.Test.Models.Api.Request; + +public class PushSendRequestModelTests +{ + [Theory] + [InlineData(null, null)] + [InlineData(null, "")] + [InlineData(null, " ")] + [InlineData("", null)] + [InlineData(" ", null)] + [InlineData("", "")] + [InlineData(" ", " ")] + public void Validate_UserIdOrganizationIdNullOrEmpty_Invalid(string? userId, string? organizationId) + { + var model = new PushSendRequestModel + { + UserId = userId, + OrganizationId = organizationId, + Type = PushType.SyncCiphers, + Payload = "test" + }; + + var results = Validate(model); + + Assert.Single(results); + Assert.Contains(results, result => result.ErrorMessage == "UserId or OrganizationId is required."); + } + + [Theory] + [BitAutoData("Payload")] + [BitAutoData("Type")] + public void Validate_RequiredFieldNotProvided_Invalid(string requiredField) + { + var model = new PushSendRequestModel + { + UserId = Guid.NewGuid().ToString(), + OrganizationId = Guid.NewGuid().ToString(), + Type = PushType.SyncCiphers, + Payload = "test" + }; + + var dictionary = new Dictionary(); + foreach (var property in model.GetType().GetProperties()) + { + if (property.Name == requiredField) + { + continue; + } + + dictionary[property.Name] = property.GetValue(model); + } + + var serialized = JsonSerializer.Serialize(dictionary, JsonHelpers.IgnoreWritingNull); + var jsonException = + Assert.Throws(() => JsonSerializer.Deserialize(serialized)); + Assert.Contains($"missing required properties, including the following: {requiredField}", + jsonException.Message); + } + + [Fact] + public void Validate_AllFieldsPresent_Valid() + { + var model = new PushSendRequestModel + { + UserId = Guid.NewGuid().ToString(), + OrganizationId = Guid.NewGuid().ToString(), + Type = PushType.SyncCiphers, + Payload = "test payload", + Identifier = Guid.NewGuid().ToString(), + ClientType = ClientType.All, + DeviceId = Guid.NewGuid().ToString() + }; + + var results = Validate(model); + + Assert.Empty(results); + } + + private static List Validate(PushSendRequestModel model) + { + var results = new List(); + Validator.TryValidateObject(model, new ValidationContext(model), results, true); + return results; + } +} diff --git a/test/Core.Test/NotificationCenter/Commands/CreateNotificationCommandTest.cs b/test/Core.Test/NotificationCenter/Commands/CreateNotificationCommandTest.cs index 4f5842d1c7..a51feb6a73 100644 --- a/test/Core.Test/NotificationCenter/Commands/CreateNotificationCommandTest.cs +++ b/test/Core.Test/NotificationCenter/Commands/CreateNotificationCommandTest.cs @@ -5,6 +5,7 @@ using Bit.Core.NotificationCenter.Authorization; using Bit.Core.NotificationCenter.Commands; using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Repositories; +using Bit.Core.Platform.Push; using Bit.Core.Test.NotificationCenter.AutoFixture; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -55,5 +56,8 @@ public class CreateNotificationCommandTest Assert.Equal(notification, newNotification); Assert.Equal(DateTime.UtcNow, notification.CreationDate, TimeSpan.FromMinutes(1)); Assert.Equal(notification.CreationDate, notification.RevisionDate); + await sutProvider.GetDependency() + .Received(1) + .PushSyncNotificationAsync(newNotification); } } diff --git a/test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs b/test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs index c26fc23460..dc391b9801 100644 --- a/test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs +++ b/test/Core.Test/NotificationHub/NotificationHubPushNotificationServiceTests.cs @@ -1,42 +1,236 @@ -using Bit.Core.NotificationHub; -using Bit.Core.Platform.Push; +#nullable enable +using System.Text.Json; +using Bit.Core.Enums; +using Bit.Core.Models; +using Bit.Core.Models.Data; +using Bit.Core.NotificationCenter.Entities; +using Bit.Core.NotificationHub; using Bit.Core.Repositories; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Logging; +using Bit.Core.Test.NotificationCenter.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; namespace Bit.Core.Test.NotificationHub; +[SutProviderCustomize] public class NotificationHubPushNotificationServiceTests { - private readonly NotificationHubPushNotificationService _sut; - - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly INotificationHubPool _notificationHubPool; - private readonly IHttpContextAccessor _httpContextAccessor; - private readonly ILogger _logger; - - public NotificationHubPushNotificationServiceTests() + [Theory] + [BitAutoData] + [NotificationCustomize] + public async void PushSyncNotificationAsync_Global_NotSent( + SutProvider sutProvider, Notification notification) { - _installationDeviceRepository = Substitute.For(); - _httpContextAccessor = Substitute.For(); - _notificationHubPool = Substitute.For(); - _logger = Substitute.For>(); + await sutProvider.Sut.PushSyncNotificationAsync(notification); - _sut = new NotificationHubPushNotificationService( - _installationDeviceRepository, - _notificationHubPool, - _httpContextAccessor, - _logger - ); + await sutProvider.GetDependency() + .Received(0) + .AllClients + .Received(0) + .SendTemplateNotificationAsync(Arg.Any>(), Arg.Any()); + await sutProvider.GetDependency() + .Received(0) + .UpsertAsync(Arg.Any()); } - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() + [Theory] + [BitAutoData(false)] + [BitAutoData(true)] + [NotificationCustomize(false)] + public async void PushSyncNotificationAsync_UserIdProvidedClientTypeAll_SentToUser( + bool organizationIdNull, SutProvider sutProvider, + Notification notification) { - Assert.NotNull(_sut); + if (organizationIdNull) + { + notification.OrganizationId = null; + } + + notification.ClientType = ClientType.All; + var expectedSyncNotification = ToSyncNotificationPushNotification(notification); + + await sutProvider.Sut.PushSyncNotificationAsync(notification); + + await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification, + $"(template:payload_userId:{notification.UserId})"); + await sutProvider.GetDependency() + .Received(0) + .UpsertAsync(Arg.Any()); + } + + [Theory] + [BitAutoData(false, ClientType.Browser)] + [BitAutoData(false, ClientType.Desktop)] + [BitAutoData(false, ClientType.Web)] + [BitAutoData(false, ClientType.Mobile)] + [BitAutoData(true, ClientType.Browser)] + [BitAutoData(true, ClientType.Desktop)] + [BitAutoData(true, ClientType.Web)] + [BitAutoData(true, ClientType.Mobile)] + [NotificationCustomize(false)] + public async void PushSyncNotificationAsync_UserIdProvidedClientTypeNotAll_SentToUser(bool organizationIdNull, + ClientType clientType, SutProvider sutProvider, + Notification notification) + { + if (organizationIdNull) + { + notification.OrganizationId = null; + } + + notification.ClientType = clientType; + var expectedSyncNotification = ToSyncNotificationPushNotification(notification); + + await sutProvider.Sut.PushSyncNotificationAsync(notification); + + await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification, + $"(template:payload_userId:{notification.UserId} && clientType:{clientType})"); + await sutProvider.GetDependency() + .Received(0) + .UpsertAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + [NotificationCustomize(false)] + public async void PushSyncNotificationAsync_UserIdNullOrganizationIdProvidedClientTypeAll_SentToOrganization( + SutProvider sutProvider, Notification notification) + { + notification.UserId = null; + notification.ClientType = ClientType.All; + var expectedSyncNotification = ToSyncNotificationPushNotification(notification); + + await sutProvider.Sut.PushSyncNotificationAsync(notification); + + await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification, + $"(template:payload && organizationId:{notification.OrganizationId})"); + await sutProvider.GetDependency() + .Received(0) + .UpsertAsync(Arg.Any()); + } + + [Theory] + [BitAutoData(ClientType.Browser)] + [BitAutoData(ClientType.Desktop)] + [BitAutoData(ClientType.Web)] + [BitAutoData(ClientType.Mobile)] + [NotificationCustomize(false)] + public async void PushSyncNotificationAsync_UserIdNullOrganizationIdProvidedClientTypeNotAll_SentToOrganization( + ClientType clientType, SutProvider sutProvider, + Notification notification) + { + notification.UserId = null; + notification.ClientType = clientType; + + var expectedSyncNotification = ToSyncNotificationPushNotification(notification); + + await sutProvider.Sut.PushSyncNotificationAsync(notification); + + await AssertSendTemplateNotificationAsync(sutProvider, PushType.SyncNotification, expectedSyncNotification, + $"(template:payload && organizationId:{notification.OrganizationId} && clientType:{clientType})"); + await sutProvider.GetDependency() + .Received(0) + .UpsertAsync(Arg.Any()); + } + + [Theory] + [BitAutoData([null])] + [BitAutoData(ClientType.All)] + public async void SendPayloadToUserAsync_ClientTypeNullOrAll_SentToUser(ClientType? clientType, + SutProvider sutProvider, Guid userId, PushType pushType, string payload, + string identifier) + { + await sutProvider.Sut.SendPayloadToUserAsync(userId.ToString(), pushType, payload, identifier, null, + clientType); + + await AssertSendTemplateNotificationAsync(sutProvider, pushType, payload, + $"(template:payload_userId:{userId} && !deviceIdentifier:{identifier})"); + await sutProvider.GetDependency() + .Received(0) + .UpsertAsync(Arg.Any()); + } + + [Theory] + [BitAutoData(ClientType.Browser)] + [BitAutoData(ClientType.Desktop)] + [BitAutoData(ClientType.Mobile)] + [BitAutoData(ClientType.Web)] + public async void SendPayloadToUserAsync_ClientTypeExplicit_SentToUserAndClientType(ClientType clientType, + SutProvider sutProvider, Guid userId, PushType pushType, string payload, + string identifier) + { + await sutProvider.Sut.SendPayloadToUserAsync(userId.ToString(), pushType, payload, identifier, null, + clientType); + + await AssertSendTemplateNotificationAsync(sutProvider, pushType, payload, + $"(template:payload_userId:{userId} && !deviceIdentifier:{identifier} && clientType:{clientType})"); + await sutProvider.GetDependency() + .Received(0) + .UpsertAsync(Arg.Any()); + } + + [Theory] + [BitAutoData([null])] + [BitAutoData(ClientType.All)] + public async void SendPayloadToOrganizationAsync_ClientTypeNullOrAll_SentToOrganization(ClientType? clientType, + SutProvider sutProvider, Guid organizationId, PushType pushType, + string payload, string identifier) + { + await sutProvider.Sut.SendPayloadToOrganizationAsync(organizationId.ToString(), pushType, payload, identifier, + null, clientType); + + await AssertSendTemplateNotificationAsync(sutProvider, pushType, payload, + $"(template:payload && organizationId:{organizationId} && !deviceIdentifier:{identifier})"); + await sutProvider.GetDependency() + .Received(0) + .UpsertAsync(Arg.Any()); + } + + [Theory] + [BitAutoData(ClientType.Browser)] + [BitAutoData(ClientType.Desktop)] + [BitAutoData(ClientType.Mobile)] + [BitAutoData(ClientType.Web)] + public async void SendPayloadToOrganizationAsync_ClientTypeExplicit_SentToOrganizationAndClientType( + ClientType clientType, SutProvider sutProvider, Guid organizationId, + PushType pushType, string payload, string identifier) + { + await sutProvider.Sut.SendPayloadToOrganizationAsync(organizationId.ToString(), pushType, payload, identifier, + null, clientType); + + await AssertSendTemplateNotificationAsync(sutProvider, pushType, payload, + $"(template:payload && organizationId:{organizationId} && !deviceIdentifier:{identifier} && clientType:{clientType})"); + await sutProvider.GetDependency() + .Received(0) + .UpsertAsync(Arg.Any()); + } + + private static SyncNotificationPushNotification ToSyncNotificationPushNotification(Notification notification) => + new() + { + Id = notification.Id, + UserId = notification.UserId, + OrganizationId = notification.OrganizationId, + ClientType = notification.ClientType, + RevisionDate = notification.RevisionDate + }; + + private static async Task AssertSendTemplateNotificationAsync( + SutProvider sutProvider, PushType type, object payload, string tag) + { + await sutProvider.GetDependency() + .Received(1) + .AllClients + .Received(1) + .SendTemplateNotificationAsync( + Arg.Is>(dictionary => MatchingSendPayload(dictionary, type, payload)), + tag); + } + + private static bool MatchingSendPayload(IDictionary dictionary, PushType type, object payload) + { + return dictionary.ContainsKey("type") && dictionary["type"].Equals(((byte)type).ToString()) && + dictionary.ContainsKey("payload") && dictionary["payload"].Equals(JsonSerializer.Serialize(payload)); } } diff --git a/test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs b/test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs index c5851f2791..d51df9c882 100644 --- a/test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs +++ b/test/Core.Test/NotificationHub/NotificationHubPushRegistrationServiceTests.cs @@ -1,44 +1,290 @@ -using Bit.Core.NotificationHub; -using Bit.Core.Repositories; -using Bit.Core.Settings; -using Microsoft.Extensions.Logging; +#nullable enable +using Bit.Core.Enums; +using Bit.Core.NotificationHub; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Azure.NotificationHubs; using NSubstitute; using Xunit; namespace Bit.Core.Test.NotificationHub; +[SutProviderCustomize] public class NotificationHubPushRegistrationServiceTests { - private readonly NotificationHubPushRegistrationService _sut; - - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly IServiceProvider _serviceProvider; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly INotificationHubPool _notificationHubPool; - - public NotificationHubPushRegistrationServiceTests() + [Theory] + [BitAutoData([null])] + [BitAutoData("")] + [BitAutoData(" ")] + public async Task CreateOrUpdateRegistrationAsync_PushTokenNullOrEmpty_InstallationNotCreated(string? pushToken, + SutProvider sutProvider, Guid deviceId, Guid userId, Guid identifier, + Guid organizationId) { - _installationDeviceRepository = Substitute.For(); - _serviceProvider = Substitute.For(); - _logger = Substitute.For>(); - _globalSettings = new GlobalSettings(); - _notificationHubPool = Substitute.For(); + await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(), + identifier.ToString(), DeviceType.Android, [organizationId.ToString()]); - _sut = new NotificationHubPushRegistrationService( - _installationDeviceRepository, - _globalSettings, - _notificationHubPool, - _serviceProvider, - _logger - ); + sutProvider.GetDependency() + .Received(0) + .ClientFor(deviceId); } - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() + [Theory] + [BitAutoData(false, false)] + [BitAutoData(false, true)] + [BitAutoData(true, false)] + [BitAutoData(true, true)] + public async Task CreateOrUpdateRegistrationAsync_DeviceTypeAndroid_InstallationCreated(bool identifierNull, + bool partOfOrganizationId, SutProvider sutProvider, Guid deviceId, + Guid userId, Guid? identifier, Guid organizationId) { - Assert.NotNull(_sut); + var notificationHubClient = Substitute.For(); + sutProvider.GetDependency().ClientFor(Arg.Any()).Returns(notificationHubClient); + + var pushToken = "test push token"; + + await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(), + identifierNull ? null : identifier.ToString(), DeviceType.Android, + partOfOrganizationId ? [organizationId.ToString()] : []); + + sutProvider.GetDependency() + .Received(1) + .ClientFor(deviceId); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => + installation.InstallationId == deviceId.ToString() && + installation.PushChannel == pushToken && + installation.Platform == NotificationPlatform.FcmV1 && + installation.Tags.Contains($"userId:{userId}") && + installation.Tags.Contains("clientType:Mobile") && + (identifierNull || installation.Tags.Contains($"deviceIdentifier:{identifier}")) && + (!partOfOrganizationId || installation.Tags.Contains($"organizationId:{organizationId}")) && + installation.Templates.Count == 3)); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => MatchingInstallationTemplate( + installation.Templates, "template:payload", + "{\"message\":{\"data\":{\"type\":\"$(type)\",\"payload\":\"$(payload)\"}}}", + new List + { + "template:payload", + $"template:payload_userId:{userId}", + "clientType:Mobile", + identifierNull ? null : $"template:payload_deviceIdentifier:{identifier}", + partOfOrganizationId ? $"organizationId:{organizationId}" : null, + }))); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => MatchingInstallationTemplate( + installation.Templates, "template:message", + "{\"message\":{\"data\":{\"type\":\"$(type)\"},\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}", + new List + { + "template:message", + $"template:message_userId:{userId}", + "clientType:Mobile", + identifierNull ? null : $"template:message_deviceIdentifier:{identifier}", + partOfOrganizationId ? $"organizationId:{organizationId}" : null, + }))); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => MatchingInstallationTemplate( + installation.Templates, "template:badgeMessage", + "{\"message\":{\"data\":{\"type\":\"$(type)\"},\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}", + new List + { + "template:badgeMessage", + $"template:badgeMessage_userId:{userId}", + "clientType:Mobile", + identifierNull ? null : $"template:badgeMessage_deviceIdentifier:{identifier}", + partOfOrganizationId ? $"organizationId:{organizationId}" : null, + }))); + } + + [Theory] + [BitAutoData(false, false)] + [BitAutoData(false, true)] + [BitAutoData(true, false)] + [BitAutoData(true, true)] + public async Task CreateOrUpdateRegistrationAsync_DeviceTypeIOS_InstallationCreated(bool identifierNull, + bool partOfOrganizationId, SutProvider sutProvider, Guid deviceId, + Guid userId, Guid identifier, Guid organizationId) + { + var notificationHubClient = Substitute.For(); + sutProvider.GetDependency().ClientFor(Arg.Any()).Returns(notificationHubClient); + + var pushToken = "test push token"; + + await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(), + identifierNull ? null : identifier.ToString(), DeviceType.iOS, + partOfOrganizationId ? [organizationId.ToString()] : []); + + sutProvider.GetDependency() + .Received(1) + .ClientFor(deviceId); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => + installation.InstallationId == deviceId.ToString() && + installation.PushChannel == pushToken && + installation.Platform == NotificationPlatform.Apns && + installation.Tags.Contains($"userId:{userId}") && + installation.Tags.Contains("clientType:Mobile") && + (identifierNull || installation.Tags.Contains($"deviceIdentifier:{identifier}")) && + (!partOfOrganizationId || installation.Tags.Contains($"organizationId:{organizationId}")) && + installation.Templates.Count == 3)); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => MatchingInstallationTemplate( + installation.Templates, "template:payload", + "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"},\"aps\":{\"content-available\":1}}", + new List + { + "template:payload", + $"template:payload_userId:{userId}", + "clientType:Mobile", + identifierNull ? null : $"template:payload_deviceIdentifier:{identifier}", + partOfOrganizationId ? $"organizationId:{organizationId}" : null, + }))); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => MatchingInstallationTemplate( + installation.Templates, "template:message", + "{\"data\":{\"type\":\"#(type)\"},\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}", + new List + { + "template:message", + $"template:message_userId:{userId}", + "clientType:Mobile", + identifierNull ? null : $"template:message_deviceIdentifier:{identifier}", + partOfOrganizationId ? $"organizationId:{organizationId}" : null, + }))); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => MatchingInstallationTemplate( + installation.Templates, "template:badgeMessage", + "{\"data\":{\"type\":\"#(type)\"},\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}", + new List + { + "template:badgeMessage", + $"template:badgeMessage_userId:{userId}", + "clientType:Mobile", + identifierNull ? null : $"template:badgeMessage_deviceIdentifier:{identifier}", + partOfOrganizationId ? $"organizationId:{organizationId}" : null, + }))); + } + + [Theory] + [BitAutoData(false, false)] + [BitAutoData(false, true)] + [BitAutoData(true, false)] + [BitAutoData(true, true)] + public async Task CreateOrUpdateRegistrationAsync_DeviceTypeAndroidAmazon_InstallationCreated(bool identifierNull, + bool partOfOrganizationId, SutProvider sutProvider, Guid deviceId, + Guid userId, Guid identifier, Guid organizationId) + { + var notificationHubClient = Substitute.For(); + sutProvider.GetDependency().ClientFor(Arg.Any()).Returns(notificationHubClient); + + var pushToken = "test push token"; + + await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(), + identifierNull ? null : identifier.ToString(), DeviceType.AndroidAmazon, + partOfOrganizationId ? [organizationId.ToString()] : []); + + sutProvider.GetDependency() + .Received(1) + .ClientFor(deviceId); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => + installation.InstallationId == deviceId.ToString() && + installation.PushChannel == pushToken && + installation.Platform == NotificationPlatform.Adm && + installation.Tags.Contains($"userId:{userId}") && + installation.Tags.Contains("clientType:Mobile") && + (identifierNull || installation.Tags.Contains($"deviceIdentifier:{identifier}")) && + (!partOfOrganizationId || installation.Tags.Contains($"organizationId:{organizationId}")) && + installation.Templates.Count == 3)); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => MatchingInstallationTemplate( + installation.Templates, "template:payload", + "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}", + new List + { + "template:payload", + $"template:payload_userId:{userId}", + "clientType:Mobile", + identifierNull ? null : $"template:payload_deviceIdentifier:{identifier}", + partOfOrganizationId ? $"organizationId:{organizationId}" : null, + }))); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => MatchingInstallationTemplate( + installation.Templates, "template:message", + "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}", + new List + { + "template:message", + $"template:message_userId:{userId}", + "clientType:Mobile", + identifierNull ? null : $"template:message_deviceIdentifier:{identifier}", + partOfOrganizationId ? $"organizationId:{organizationId}" : null, + }))); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => MatchingInstallationTemplate( + installation.Templates, "template:badgeMessage", + "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}", + new List + { + "template:badgeMessage", + $"template:badgeMessage_userId:{userId}", + "clientType:Mobile", + identifierNull ? null : $"template:badgeMessage_deviceIdentifier:{identifier}", + partOfOrganizationId ? $"organizationId:{organizationId}" : null, + }))); + } + + [Theory] + [BitAutoData(DeviceType.ChromeBrowser)] + [BitAutoData(DeviceType.ChromeExtension)] + [BitAutoData(DeviceType.MacOsDesktop)] + public async Task CreateOrUpdateRegistrationAsync_DeviceTypeNotMobile_InstallationCreated(DeviceType deviceType, + SutProvider sutProvider, Guid deviceId, Guid userId, Guid identifier, + Guid organizationId) + { + var notificationHubClient = Substitute.For(); + sutProvider.GetDependency().ClientFor(Arg.Any()).Returns(notificationHubClient); + + var pushToken = "test push token"; + + await sutProvider.Sut.CreateOrUpdateRegistrationAsync(pushToken, deviceId.ToString(), userId.ToString(), + identifier.ToString(), deviceType, [organizationId.ToString()]); + + sutProvider.GetDependency() + .Received(1) + .ClientFor(deviceId); + await notificationHubClient + .Received(1) + .CreateOrUpdateInstallationAsync(Arg.Is(installation => + installation.InstallationId == deviceId.ToString() && + installation.PushChannel == pushToken && + installation.Tags.Contains($"userId:{userId}") && + installation.Tags.Contains($"clientType:{DeviceTypes.ToClientType(deviceType)}") && + installation.Tags.Contains($"deviceIdentifier:{identifier}") && + installation.Tags.Contains($"organizationId:{organizationId}") && + installation.Templates.Count == 0)); + } + + private static bool MatchingInstallationTemplate(IDictionary templates, string key, + string body, List tags) + { + var tagsNoNulls = tags.FindAll(tag => tag != null); + return templates.ContainsKey(key) && templates[key].Body == body && + templates[key].Tags.Count == tagsNoNulls.Count && + templates[key].Tags.All(tagsNoNulls.Contains); } } diff --git a/test/Core.Test/Platform/Push/Services/AzureQueuePushNotificationServiceTests.cs b/test/Core.Test/Platform/Push/Services/AzureQueuePushNotificationServiceTests.cs index 85ce5a79ac..7aa053ec6d 100644 --- a/test/Core.Test/Platform/Push/Services/AzureQueuePushNotificationServiceTests.cs +++ b/test/Core.Test/Platform/Push/Services/AzureQueuePushNotificationServiceTests.cs @@ -1,33 +1,66 @@ -using Bit.Core.Settings; +#nullable enable +using System.Text.Json; +using Azure.Storage.Queues; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Models; +using Bit.Core.NotificationCenter.Entities; +using Bit.Core.Test.AutoFixture; +using Bit.Core.Test.AutoFixture.CurrentContextFixtures; +using Bit.Core.Test.NotificationCenter.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Http; using NSubstitute; using Xunit; namespace Bit.Core.Platform.Push.Internal.Test; +[QueueClientCustomize] +[SutProviderCustomize] public class AzureQueuePushNotificationServiceTests { - private readonly AzureQueuePushNotificationService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - - public AzureQueuePushNotificationServiceTests() + [Theory] + [BitAutoData] + [NotificationCustomize] + [CurrentContextCustomize] + public async void PushSyncNotificationAsync_Notification_Sent( + SutProvider sutProvider, Notification notification, Guid deviceIdentifier, + ICurrentContext currentContext) { - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); + currentContext.DeviceIdentifier.Returns(deviceIdentifier.ToString()); + sutProvider.GetDependency().HttpContext!.RequestServices + .GetService(Arg.Any()).Returns(currentContext); - _sut = new AzureQueuePushNotificationService( - _globalSettings, - _httpContextAccessor - ); + await sutProvider.Sut.PushSyncNotificationAsync(notification); + + await sutProvider.GetDependency().Received(1) + .SendMessageAsync(Arg.Is(message => + MatchMessage(PushType.SyncNotification, message, new SyncNotificationEquals(notification), + deviceIdentifier.ToString()))); } - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() + private static bool MatchMessage(PushType pushType, string message, IEquatable expectedPayloadEquatable, + string contextId) { - Assert.NotNull(_sut); + var pushNotificationData = + JsonSerializer.Deserialize>(message); + return pushNotificationData != null && + pushNotificationData.Type == pushType && + expectedPayloadEquatable.Equals(pushNotificationData.Payload) && + pushNotificationData.ContextId == contextId; + } + + private class SyncNotificationEquals(Notification notification) : IEquatable + { + public bool Equals(SyncNotificationPushNotification? other) + { + return other != null && + other.Id == notification.Id && + other.UserId == notification.UserId && + other.OrganizationId == notification.OrganizationId && + other.ClientType == notification.ClientType && + other.RevisionDate == notification.RevisionDate; + } } } diff --git a/test/Core.Test/Platform/Push/Services/MultiServicePushNotificationServiceTests.cs b/test/Core.Test/Platform/Push/Services/MultiServicePushNotificationServiceTests.cs index 021aa7f2cc..35997f80e9 100644 --- a/test/Core.Test/Platform/Push/Services/MultiServicePushNotificationServiceTests.cs +++ b/test/Core.Test/Platform/Push/Services/MultiServicePushNotificationServiceTests.cs @@ -1,44 +1,62 @@ -using AutoFixture; +#nullable enable +using Bit.Core.Enums; +using Bit.Core.NotificationCenter.Entities; +using Bit.Core.Test.NotificationCenter.AutoFixture; using Bit.Test.Common.AutoFixture; -using Microsoft.Extensions.Logging; +using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -using GlobalSettingsCustomization = Bit.Test.Common.AutoFixture.GlobalSettings; namespace Bit.Core.Platform.Push.Internal.Test; +[SutProviderCustomize] public class MultiServicePushNotificationServiceTests { - private readonly MultiServicePushNotificationService _sut; - - private readonly ILogger _logger; - private readonly ILogger _relayLogger; - private readonly ILogger _hubLogger; - private readonly IEnumerable _services; - private readonly Settings.GlobalSettings _globalSettings; - - public MultiServicePushNotificationServiceTests() + [Theory] + [BitAutoData] + [NotificationCustomize] + public async Task PushSyncNotificationAsync_Notification_Sent( + SutProvider sutProvider, Notification notification) { - _logger = Substitute.For>(); - _relayLogger = Substitute.For>(); - _hubLogger = Substitute.For>(); + await sutProvider.Sut.PushSyncNotificationAsync(notification); - var fixture = new Fixture().WithAutoNSubstitutions().Customize(new GlobalSettingsCustomization()); - _services = fixture.CreateMany(); - _globalSettings = fixture.Create(); - - _sut = new MultiServicePushNotificationService( - _services, - _logger, - _globalSettings - ); + await sutProvider.GetDependency>() + .First() + .Received(1) + .PushSyncNotificationAsync(notification); } - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() + [Theory] + [BitAutoData([null, null])] + [BitAutoData(ClientType.All, null)] + [BitAutoData([null, "test device id"])] + [BitAutoData(ClientType.All, "test device id")] + public async Task SendPayloadToUserAsync_Message_Sent(ClientType? clientType, string? deviceId, string userId, + PushType type, object payload, string identifier, SutProvider sutProvider) { - Assert.NotNull(_sut); + await sutProvider.Sut.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId, clientType); + + await sutProvider.GetDependency>() + .First() + .Received(1) + .SendPayloadToUserAsync(userId, type, payload, identifier, deviceId, clientType); + } + + [Theory] + [BitAutoData([null, null])] + [BitAutoData(ClientType.All, null)] + [BitAutoData([null, "test device id"])] + [BitAutoData(ClientType.All, "test device id")] + public async Task SendPayloadToOrganizationAsync_Message_Sent(ClientType? clientType, string? deviceId, + string organizationId, PushType type, object payload, string identifier, + SutProvider sutProvider) + { + await sutProvider.Sut.SendPayloadToOrganizationAsync(organizationId, type, payload, identifier, deviceId, + clientType); + + await sutProvider.GetDependency>() + .First() + .Received(1) + .SendPayloadToOrganizationAsync(organizationId, type, payload, identifier, deviceId, clientType); } } diff --git a/test/Core.Test/Services/DeviceServiceTests.cs b/test/Core.Test/Services/DeviceServiceTests.cs index 41ef0b4d74..98b04eb7d3 100644 --- a/test/Core.Test/Services/DeviceServiceTests.cs +++ b/test/Core.Test/Services/DeviceServiceTests.cs @@ -3,6 +3,7 @@ using Bit.Core.Auth.Models.Api.Request; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -16,15 +17,23 @@ namespace Bit.Core.Test.Services; [SutProviderCustomize] public class DeviceServiceTests { - [Fact] - public async Task DeviceSaveShouldUpdateRevisionDateAndPushRegistration() + [Theory] + [BitAutoData] + public async Task SaveAsync_IdProvided_UpdatedRevisionDateAndPushRegistration(Guid id, Guid userId, + Guid organizationId1, Guid organizationId2, + OrganizationUserOrganizationDetails organizationUserOrganizationDetails1, + OrganizationUserOrganizationDetails organizationUserOrganizationDetails2) { + organizationUserOrganizationDetails1.OrganizationId = organizationId1; + organizationUserOrganizationDetails2.OrganizationId = organizationId2; + var deviceRepo = Substitute.For(); var pushRepo = Substitute.For(); - var deviceService = new DeviceService(deviceRepo, pushRepo); + var organizationUserRepository = Substitute.For(); + organizationUserRepository.GetManyDetailsByUserAsync(Arg.Any(), Arg.Any()) + .Returns([organizationUserOrganizationDetails1, organizationUserOrganizationDetails2]); + var deviceService = new DeviceService(deviceRepo, pushRepo, organizationUserRepository); - var id = Guid.NewGuid(); - var userId = Guid.NewGuid(); var device = new Device { Id = id, @@ -37,8 +46,53 @@ public class DeviceServiceTests await deviceService.SaveAsync(device); Assert.True(device.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); - await pushRepo.Received().CreateOrUpdateRegistrationAsync("testtoken", id.ToString(), - userId.ToString(), "testid", DeviceType.Android); + await pushRepo.Received(1).CreateOrUpdateRegistrationAsync("testtoken", id.ToString(), + userId.ToString(), "testid", DeviceType.Android, + Arg.Do>(organizationIds => + { + var organizationIdsList = organizationIds.ToList(); + Assert.Equal(2, organizationIdsList.Count); + Assert.Contains(organizationId1.ToString(), organizationIdsList); + Assert.Contains(organizationId2.ToString(), organizationIdsList); + })); + } + + [Theory] + [BitAutoData] + public async Task SaveAsync_IdNotProvided_CreatedAndPushRegistration(Guid userId, Guid organizationId1, + Guid organizationId2, + OrganizationUserOrganizationDetails organizationUserOrganizationDetails1, + OrganizationUserOrganizationDetails organizationUserOrganizationDetails2) + { + organizationUserOrganizationDetails1.OrganizationId = organizationId1; + organizationUserOrganizationDetails2.OrganizationId = organizationId2; + + var deviceRepo = Substitute.For(); + var pushRepo = Substitute.For(); + var organizationUserRepository = Substitute.For(); + organizationUserRepository.GetManyDetailsByUserAsync(Arg.Any(), Arg.Any()) + .Returns([organizationUserOrganizationDetails1, organizationUserOrganizationDetails2]); + var deviceService = new DeviceService(deviceRepo, pushRepo, organizationUserRepository); + + var device = new Device + { + Name = "test device", + Type = DeviceType.Android, + UserId = userId, + PushToken = "testtoken", + Identifier = "testid" + }; + await deviceService.SaveAsync(device); + + await pushRepo.Received(1).CreateOrUpdateRegistrationAsync("testtoken", + Arg.Do(id => Guid.TryParse(id, out var _)), userId.ToString(), "testid", DeviceType.Android, + Arg.Do>(organizationIds => + { + var organizationIdsList = organizationIds.ToList(); + Assert.Equal(2, organizationIdsList.Count); + Assert.Contains(organizationId1.ToString(), organizationIdsList); + Assert.Contains(organizationId2.ToString(), organizationIdsList); + })); } /// @@ -62,12 +116,7 @@ public class DeviceServiceTests sutProvider.GetDependency() .GetManyByUserIdAsync(currentUserId) - .Returns(new List - { - deviceOne, - deviceTwo, - deviceThree, - }); + .Returns(new List { deviceOne, deviceTwo, deviceThree, }); var currentDeviceModel = new DeviceKeysUpdateRequestModel { @@ -85,7 +134,8 @@ public class DeviceServiceTests }, }; - await sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel, alteredDeviceModels); + await sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel, + alteredDeviceModels); // Updating trust, "current" or "other" only needs to change the EncryptedPublicKey & EncryptedUserKey await sutProvider.GetDependency() @@ -149,11 +199,7 @@ public class DeviceServiceTests sutProvider.GetDependency() .GetManyByUserIdAsync(currentUserId) - .Returns(new List - { - deviceOne, - deviceTwo, - }); + .Returns(new List { deviceOne, deviceTwo, }); var currentDeviceModel = new DeviceKeysUpdateRequestModel { @@ -171,7 +217,8 @@ public class DeviceServiceTests }, }; - await sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel, alteredDeviceModels); + await sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel, + alteredDeviceModels); // Check that UpsertAsync was called for the trusted device await sutProvider.GetDependency() @@ -203,11 +250,7 @@ public class DeviceServiceTests sutProvider.GetDependency() .GetManyByUserIdAsync(currentUserId) - .Returns(new List - { - deviceOne, - deviceTwo, - }); + .Returns(new List { deviceOne, deviceTwo, }); var currentDeviceModel = new DeviceKeysUpdateRequestModel { @@ -237,11 +280,7 @@ public class DeviceServiceTests sutProvider.GetDependency() .GetManyByUserIdAsync(currentUserId) - .Returns(new List - { - deviceOne, - deviceTwo, - }); + .Returns(new List { deviceOne, deviceTwo, }); var currentDeviceModel = new DeviceKeysUpdateRequestModel { @@ -260,6 +299,7 @@ public class DeviceServiceTests }; await Assert.ThrowsAsync(() => - sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel, alteredDeviceModels)); + sutProvider.Sut.UpdateDevicesTrustAsync("current_device", currentUserId, currentDeviceModel, + alteredDeviceModels)); } } diff --git a/test/Identity.IntegrationTest/openid-configuration.json b/test/Identity.IntegrationTest/openid-configuration.json index 23e5a67c06..4d74f66009 100644 --- a/test/Identity.IntegrationTest/openid-configuration.json +++ b/test/Identity.IntegrationTest/openid-configuration.json @@ -24,6 +24,7 @@ "sstamp", "premium", "device", + "devicetype", "orgowner", "orgadmin", "orguser",