From 6296c1fb1fbe06c886202800b29f9cc877317eda Mon Sep 17 00:00:00 2001 From: Maciej Zieniuk Date: Tue, 22 Oct 2024 11:11:55 +0100 Subject: [PATCH] PM-10600: Sending to specific client types for other clients --- src/Core/Models/PushNotification.cs | 1 + .../AzureQueuePushNotificationService.cs | 1 + .../NotificationHubPushNotificationService.cs | 1 + ...NotificationsApiPushNotificationService.cs | 1 + .../RelayPushNotificationService.cs | 1 + src/Notifications/HubHelpers.cs | 50 +++++++++++----- src/Notifications/NotificationsHub.cs | 57 ++++++++++++++++++- 7 files changed, 96 insertions(+), 16 deletions(-) diff --git a/src/Core/Models/PushNotification.cs b/src/Core/Models/PushNotification.cs index e27bea364c..cf7fc850bb 100644 --- a/src/Core/Models/PushNotification.cs +++ b/src/Core/Models/PushNotification.cs @@ -51,6 +51,7 @@ public class SyncNotificationPushNotification public bool Global { get; set; } public Guid? UserId { get; set; } public Guid? OrganizationId { get; set; } + public ClientType ClientType { get; set; } public DateTime RevisionDate { get; set; } } diff --git a/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs b/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs index b5c6677064..3349aada61 100644 --- a/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs +++ b/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs @@ -174,6 +174,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService Global = notification.Global, UserId = notification.Id, OrganizationId = notification.Id, + ClientType = notification.ClientType, RevisionDate = notification.RevisionDate }; diff --git a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs b/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs index 8e4ca97674..b302d28010 100644 --- a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs +++ b/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs @@ -202,6 +202,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService Global = notification.Global, UserId = notification.Id, OrganizationId = notification.Id, + ClientType = notification.ClientType, RevisionDate = notification.RevisionDate }; diff --git a/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs b/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs index a43720c4fa..bde4348534 100644 --- a/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs +++ b/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs @@ -181,6 +181,7 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService Global = notification.Global, UserId = notification.Id, OrganizationId = notification.Id, + ClientType = notification.ClientType, RevisionDate = notification.RevisionDate }; diff --git a/src/Core/Services/Implementations/RelayPushNotificationService.cs b/src/Core/Services/Implementations/RelayPushNotificationService.cs index 16db74acf6..6e5fd5d42f 100644 --- a/src/Core/Services/Implementations/RelayPushNotificationService.cs +++ b/src/Core/Services/Implementations/RelayPushNotificationService.cs @@ -197,6 +197,7 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti Global = notification.Global, UserId = notification.Id, OrganizationId = notification.Id, + ClientType = notification.ClientType, RevisionDate = notification.RevisionDate }; diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 8736643e2b..a653a7c3ba 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -10,6 +10,8 @@ public static class HubHelpers private static JsonSerializerOptions _deserializerOptions = new JsonSerializerOptions { PropertyNameCaseInsensitive = true }; + private static readonly string _receiveMessageMethod = "ReceiveMessage"; + public static async Task SendNotificationToHubAsync( string notificationJson, IHubContext hubContext, @@ -33,13 +35,13 @@ public static class HubHelpers if (cipherNotification.Payload.UserId.HasValue) { await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } else if (cipherNotification.Payload.OrganizationId.HasValue) { - await hubContext.Clients.Group( - $"Organization_{cipherNotification.Payload.OrganizationId}") - .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); + await hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(cipherNotification.Payload.OrganizationId.Value)) + .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } break; @@ -50,7 +52,7 @@ public static class HubHelpers JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", folderNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, folderNotification, cancellationToken); break; case PushType.SyncCiphers: case PushType.SyncVault: @@ -62,7 +64,7 @@ public static class HubHelpers JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", userNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, userNotification, cancellationToken); break; case PushType.SyncSendCreate: case PushType.SyncSendUpdate: @@ -71,7 +73,7 @@ public static class HubHelpers JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", sendNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, sendNotification, cancellationToken); break; case PushType.AuthRequestResponse: var authRequestResponseNotification = @@ -85,7 +87,7 @@ public static class HubHelpers JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", authRequestNotification, cancellationToken); + .SendAsync(_receiveMessageMethod, authRequestNotification, cancellationToken); break; case PushType.SyncNotification: var syncNotification = @@ -93,19 +95,39 @@ public static class HubHelpers notificationJson, _deserializerOptions); if (syncNotification.Payload.Global) { - await hubContext.Clients.All.SendAsync("ReceiveMessage", syncNotification, cancellationToken); + if (syncNotification.Payload.ClientType == ClientType.All) + { + await hubContext.Clients.All.SendAsync(_receiveMessageMethod, syncNotification, + cancellationToken); + } + else + { + await hubContext.Clients + .Group(NotificationsHub.GetGlobalGroup(syncNotification.Payload.ClientType)) + .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); + } } else if (syncNotification.Payload.UserId.HasValue) { - await hubContext.Clients.User(syncNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", syncNotification, cancellationToken); + if (syncNotification.Payload.ClientType == ClientType.All) + { + await hubContext.Clients.User(syncNotification.Payload.UserId.ToString()) + .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); + } + else + { + await hubContext.Clients.Group(NotificationsHub.GetUserGroup( + syncNotification.Payload.UserId.Value, syncNotification.Payload.ClientType)) + .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); + } } else if (syncNotification.Payload.OrganizationId.HasValue) { - await hubContext.Clients.Group( - $"Organization_{syncNotification.Payload.OrganizationId}") - .SendAsync("ReceiveMessage", syncNotification, cancellationToken); + await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( + syncNotification.Payload.OrganizationId.Value, syncNotification.Payload.ClientType)) + .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken); } + break; default: break; diff --git a/src/Notifications/NotificationsHub.cs b/src/Notifications/NotificationsHub.cs index a86cf329c5..2cbe48d11f 100644 --- a/src/Notifications/NotificationsHub.cs +++ b/src/Notifications/NotificationsHub.cs @@ -1,5 +1,7 @@ using Bit.Core.Context; +using Bit.Core.Enums; using Bit.Core.Settings; +using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; namespace Bit.Notifications; @@ -20,13 +22,30 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub { var currentContext = new CurrentContext(null, null); await currentContext.BuildAsync(Context.User, _globalSettings); + + var clientType = DeviceTypes.ToClientType(currentContext.DeviceType); + if (clientType != ClientType.All) + { + await Groups.AddToGroupAsync(Context.ConnectionId, GetGlobalGroup(clientType)); + if (currentContext.UserId.HasValue) + { + await Groups.AddToGroupAsync(Context.ConnectionId, + GetUserGroup(currentContext.UserId.Value, clientType)); + } + } + if (currentContext.Organizations != null) { foreach (var org in currentContext.Organizations) { - await Groups.AddToGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); + await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id)); + if (clientType != ClientType.All) + { + await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType)); + } } } + _connectionCounter.Increment(); await base.OnConnectedAsync(); } @@ -35,14 +54,48 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub { var currentContext = new CurrentContext(null, null); await currentContext.BuildAsync(Context.User, _globalSettings); + + var clientType = DeviceTypes.ToClientType(currentContext.DeviceType); + if (clientType != ClientType.All) + { + await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetGlobalGroup(clientType)); + if (currentContext.UserId.HasValue) + { + await Groups.RemoveFromGroupAsync(Context.ConnectionId, + GetUserGroup(currentContext.UserId.Value, clientType)); + } + } + if (currentContext.Organizations != null) { foreach (var org in currentContext.Organizations) { - await Groups.RemoveFromGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); + await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id)); + if (clientType != ClientType.All) + { + await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType)); + } } } + _connectionCounter.Decrement(); await base.OnDisconnectedAsync(exception); } + + public static string GetGlobalGroup(ClientType clientType) + { + return $"ClientType_{clientType}"; + } + + public static string GetUserGroup(Guid userId, ClientType clientType) + { + return $"{userId}_{clientType}"; + } + + public static string GetOrganizationGroup(Guid organizationId, ClientType? clientType = null) + { + return clientType is not ClientType.All + ? $"Organization_{organizationId}" + : $"Organization_{organizationId}_{clientType}"; + } }