1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-01 16:12:49 -05:00

PM-10600: Sending to specific client types for other clients

This commit is contained in:
Maciej Zieniuk
2024-10-22 11:11:55 +01:00
parent 7020565770
commit 6296c1fb1f
7 changed files with 96 additions and 16 deletions

View File

@ -51,6 +51,7 @@ public class SyncNotificationPushNotification
public bool Global { get; set; } public bool Global { get; set; }
public Guid? UserId { get; set; } public Guid? UserId { get; set; }
public Guid? OrganizationId { get; set; } public Guid? OrganizationId { get; set; }
public ClientType ClientType { get; set; }
public DateTime RevisionDate { get; set; } public DateTime RevisionDate { get; set; }
} }

View File

@ -174,6 +174,7 @@ public class AzureQueuePushNotificationService : IPushNotificationService
Global = notification.Global, Global = notification.Global,
UserId = notification.Id, UserId = notification.Id,
OrganizationId = notification.Id, OrganizationId = notification.Id,
ClientType = notification.ClientType,
RevisionDate = notification.RevisionDate RevisionDate = notification.RevisionDate
}; };

View File

@ -202,6 +202,7 @@ public class NotificationHubPushNotificationService : IPushNotificationService
Global = notification.Global, Global = notification.Global,
UserId = notification.Id, UserId = notification.Id,
OrganizationId = notification.Id, OrganizationId = notification.Id,
ClientType = notification.ClientType,
RevisionDate = notification.RevisionDate RevisionDate = notification.RevisionDate
}; };

View File

@ -181,6 +181,7 @@ public class NotificationsApiPushNotificationService : BaseIdentityClientService
Global = notification.Global, Global = notification.Global,
UserId = notification.Id, UserId = notification.Id,
OrganizationId = notification.Id, OrganizationId = notification.Id,
ClientType = notification.ClientType,
RevisionDate = notification.RevisionDate RevisionDate = notification.RevisionDate
}; };

View File

@ -197,6 +197,7 @@ public class RelayPushNotificationService : BaseIdentityClientService, IPushNoti
Global = notification.Global, Global = notification.Global,
UserId = notification.Id, UserId = notification.Id,
OrganizationId = notification.Id, OrganizationId = notification.Id,
ClientType = notification.ClientType,
RevisionDate = notification.RevisionDate RevisionDate = notification.RevisionDate
}; };

View File

@ -10,6 +10,8 @@ public static class HubHelpers
private static JsonSerializerOptions _deserializerOptions = private static JsonSerializerOptions _deserializerOptions =
new JsonSerializerOptions { PropertyNameCaseInsensitive = true }; new JsonSerializerOptions { PropertyNameCaseInsensitive = true };
private static readonly string _receiveMessageMethod = "ReceiveMessage";
public static async Task SendNotificationToHubAsync( public static async Task SendNotificationToHubAsync(
string notificationJson, string notificationJson,
IHubContext<NotificationsHub> hubContext, IHubContext<NotificationsHub> hubContext,
@ -33,13 +35,13 @@ public static class HubHelpers
if (cipherNotification.Payload.UserId.HasValue) if (cipherNotification.Payload.UserId.HasValue)
{ {
await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", cipherNotification, cancellationToken); .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken);
} }
else if (cipherNotification.Payload.OrganizationId.HasValue) else if (cipherNotification.Payload.OrganizationId.HasValue)
{ {
await hubContext.Clients.Group( await hubContext.Clients
$"Organization_{cipherNotification.Payload.OrganizationId}") .Group(NotificationsHub.GetOrganizationGroup(cipherNotification.Payload.OrganizationId.Value))
.SendAsync("ReceiveMessage", cipherNotification, cancellationToken); .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken);
} }
break; break;
@ -50,7 +52,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<SyncFolderPushNotification>>( JsonSerializer.Deserialize<PushNotificationData<SyncFolderPushNotification>>(
notificationJson, _deserializerOptions); notificationJson, _deserializerOptions);
await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) await hubContext.Clients.User(folderNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", folderNotification, cancellationToken); .SendAsync(_receiveMessageMethod, folderNotification, cancellationToken);
break; break;
case PushType.SyncCiphers: case PushType.SyncCiphers:
case PushType.SyncVault: case PushType.SyncVault:
@ -62,7 +64,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<UserPushNotification>>( JsonSerializer.Deserialize<PushNotificationData<UserPushNotification>>(
notificationJson, _deserializerOptions); notificationJson, _deserializerOptions);
await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) await hubContext.Clients.User(userNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", userNotification, cancellationToken); .SendAsync(_receiveMessageMethod, userNotification, cancellationToken);
break; break;
case PushType.SyncSendCreate: case PushType.SyncSendCreate:
case PushType.SyncSendUpdate: case PushType.SyncSendUpdate:
@ -71,7 +73,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<SyncSendPushNotification>>( JsonSerializer.Deserialize<PushNotificationData<SyncSendPushNotification>>(
notificationJson, _deserializerOptions); notificationJson, _deserializerOptions);
await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) await hubContext.Clients.User(sendNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", sendNotification, cancellationToken); .SendAsync(_receiveMessageMethod, sendNotification, cancellationToken);
break; break;
case PushType.AuthRequestResponse: case PushType.AuthRequestResponse:
var authRequestResponseNotification = var authRequestResponseNotification =
@ -85,7 +87,7 @@ public static class HubHelpers
JsonSerializer.Deserialize<PushNotificationData<AuthRequestPushNotification>>( JsonSerializer.Deserialize<PushNotificationData<AuthRequestPushNotification>>(
notificationJson, _deserializerOptions); notificationJson, _deserializerOptions);
await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", authRequestNotification, cancellationToken); .SendAsync(_receiveMessageMethod, authRequestNotification, cancellationToken);
break; break;
case PushType.SyncNotification: case PushType.SyncNotification:
var syncNotification = var syncNotification =
@ -93,19 +95,39 @@ public static class HubHelpers
notificationJson, _deserializerOptions); notificationJson, _deserializerOptions);
if (syncNotification.Payload.Global) if (syncNotification.Payload.Global)
{ {
await hubContext.Clients.All.SendAsync("ReceiveMessage", syncNotification, cancellationToken); if (syncNotification.Payload.ClientType == ClientType.All)
{
await hubContext.Clients.All.SendAsync(_receiveMessageMethod, syncNotification,
cancellationToken);
}
else
{
await hubContext.Clients
.Group(NotificationsHub.GetGlobalGroup(syncNotification.Payload.ClientType))
.SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
} }
else if (syncNotification.Payload.UserId.HasValue) else if (syncNotification.Payload.UserId.HasValue)
{
if (syncNotification.Payload.ClientType == ClientType.All)
{ {
await hubContext.Clients.User(syncNotification.Payload.UserId.ToString()) await hubContext.Clients.User(syncNotification.Payload.UserId.ToString())
.SendAsync("ReceiveMessage", syncNotification, cancellationToken); .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
else
{
await hubContext.Clients.Group(NotificationsHub.GetUserGroup(
syncNotification.Payload.UserId.Value, syncNotification.Payload.ClientType))
.SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
}
} }
else if (syncNotification.Payload.OrganizationId.HasValue) else if (syncNotification.Payload.OrganizationId.HasValue)
{ {
await hubContext.Clients.Group( await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(
$"Organization_{syncNotification.Payload.OrganizationId}") syncNotification.Payload.OrganizationId.Value, syncNotification.Payload.ClientType))
.SendAsync("ReceiveMessage", syncNotification, cancellationToken); .SendAsync(_receiveMessageMethod, syncNotification, cancellationToken);
} }
break; break;
default: default:
break; break;

View File

@ -1,5 +1,7 @@
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Enums;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
namespace Bit.Notifications; namespace Bit.Notifications;
@ -20,13 +22,30 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
{ {
var currentContext = new CurrentContext(null, null); var currentContext = new CurrentContext(null, null);
await currentContext.BuildAsync(Context.User, _globalSettings); await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
if (clientType != ClientType.All)
{
await Groups.AddToGroupAsync(Context.ConnectionId, GetGlobalGroup(clientType));
if (currentContext.UserId.HasValue)
{
await Groups.AddToGroupAsync(Context.ConnectionId,
GetUserGroup(currentContext.UserId.Value, clientType));
}
}
if (currentContext.Organizations != null) if (currentContext.Organizations != null)
{ {
foreach (var org in currentContext.Organizations) foreach (var org in currentContext.Organizations)
{ {
await Groups.AddToGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id));
if (clientType != ClientType.All)
{
await Groups.AddToGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType));
} }
} }
}
_connectionCounter.Increment(); _connectionCounter.Increment();
await base.OnConnectedAsync(); await base.OnConnectedAsync();
} }
@ -35,14 +54,48 @@ public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub
{ {
var currentContext = new CurrentContext(null, null); var currentContext = new CurrentContext(null, null);
await currentContext.BuildAsync(Context.User, _globalSettings); await currentContext.BuildAsync(Context.User, _globalSettings);
var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
if (clientType != ClientType.All)
{
await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetGlobalGroup(clientType));
if (currentContext.UserId.HasValue)
{
await Groups.RemoveFromGroupAsync(Context.ConnectionId,
GetUserGroup(currentContext.UserId.Value, clientType));
}
}
if (currentContext.Organizations != null) if (currentContext.Organizations != null)
{ {
foreach (var org in currentContext.Organizations) foreach (var org in currentContext.Organizations)
{ {
await Groups.RemoveFromGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id));
if (clientType != ClientType.All)
{
await Groups.RemoveFromGroupAsync(Context.ConnectionId, GetOrganizationGroup(org.Id, clientType));
} }
} }
}
_connectionCounter.Decrement(); _connectionCounter.Decrement();
await base.OnDisconnectedAsync(exception); await base.OnDisconnectedAsync(exception);
} }
public static string GetGlobalGroup(ClientType clientType)
{
return $"ClientType_{clientType}";
}
public static string GetUserGroup(Guid userId, ClientType clientType)
{
return $"{userId}_{clientType}";
}
public static string GetOrganizationGroup(Guid organizationId, ClientType? clientType = null)
{
return clientType is not ClientType.All
? $"Organization_{organizationId}"
: $"Organization_{organizationId}_{clientType}";
}
} }