1
0
mirror of https://github.com/bitwarden/server.git synced 2025-04-04 20:50:21 -05:00

PM-11123: Device Type mapping (#4768)

* PM-11123: Device Type mapping

* PM-11123: Moving ClientType out of NotificationCenter, naming clash with Identity ClientType

* PM-11123: Rename ClientType in ICurrentContext to match the type
This commit is contained in:
Maciej Zieniuk 2024-09-23 23:02:32 +02:00 committed by GitHub
parent e1bf8a9206
commit 9a5c6fe527
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 89 additions and 52 deletions

View File

@ -28,16 +28,16 @@ public class CreateProjectCommand : ICreateProjectCommand
_currentContext = currentContext; _currentContext = currentContext;
} }
public async Task<Project> CreateAsync(Project project, Guid id, ClientType clientType) public async Task<Project> CreateAsync(Project project, Guid id, IdentityClientType identityClientType)
{ {
if (clientType != ClientType.User && clientType != ClientType.ServiceAccount) if (identityClientType != IdentityClientType.User && identityClientType != IdentityClientType.ServiceAccount)
{ {
throw new NotFoundException(); throw new NotFoundException();
} }
var createdProject = await _projectRepository.CreateAsync(project); var createdProject = await _projectRepository.CreateAsync(project);
if (clientType == ClientType.User) if (identityClientType == IdentityClientType.User)
{ {
var orgUser = await _organizationUserRepository.GetByOrganizationAsync(createdProject.OrganizationId, id); var orgUser = await _organizationUserRepository.GetByOrganizationAsync(createdProject.OrganizationId, id);
@ -52,7 +52,7 @@ public class CreateProjectCommand : ICreateProjectCommand
await _accessPolicyRepository.CreateManyAsync(new List<BaseAccessPolicy> { accessPolicy }); await _accessPolicyRepository.CreateManyAsync(new List<BaseAccessPolicy> { accessPolicy });
} }
else if (clientType == ClientType.ServiceAccount) else if (identityClientType == IdentityClientType.ServiceAccount)
{ {
var serviceAccountProjectAccessPolicy = new ServiceAccountProjectAccessPolicy() var serviceAccountProjectAccessPolicy = new ServiceAccountProjectAccessPolicy()
{ {

View File

@ -21,7 +21,7 @@ public class AccessClientQuery : IAccessClientQuery
ClaimsPrincipal claimsPrincipal, Guid organizationId) ClaimsPrincipal claimsPrincipal, Guid organizationId)
{ {
var orgAdmin = await _currentContext.OrganizationAdmin(organizationId); var orgAdmin = await _currentContext.OrganizationAdmin(organizationId);
var accessClient = AccessClientHelper.ToAccessClient(_currentContext.ClientType, orgAdmin); var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin);
var userId = _userService.GetProperUserId(claimsPrincipal).Value; var userId = _userService.GetProperUserId(claimsPrincipal).Value;
return (accessClient, userId); return (accessClient, userId);
} }

View File

@ -30,7 +30,7 @@ public class CreateProjectCommandTests
.CreateAsync(Arg.Any<Project>()) .CreateAsync(Arg.Any<Project>())
.Returns(data); .Returns(data);
await sutProvider.Sut.CreateAsync(data, userId, sutProvider.GetDependency<ICurrentContext>().ClientType); await sutProvider.Sut.CreateAsync(data, userId, sutProvider.GetDependency<ICurrentContext>().IdentityClientType);
await sutProvider.GetDependency<IProjectRepository>().Received(1) await sutProvider.GetDependency<IProjectRepository>().Received(1)
.CreateAsync(Arg.Is(data)); .CreateAsync(Arg.Is(data));

View File

@ -57,7 +57,7 @@ public class ProjectsController : Controller
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var orgAdmin = await _currentContext.OrganizationAdmin(organizationId); var orgAdmin = await _currentContext.OrganizationAdmin(organizationId);
var accessClient = AccessClientHelper.ToAccessClient(_currentContext.ClientType, orgAdmin); var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin);
var projects = await _projectRepository.GetManyByOrganizationIdAsync(organizationId, userId, accessClient); var projects = await _projectRepository.GetManyByOrganizationIdAsync(organizationId, userId, accessClient);
@ -84,7 +84,7 @@ public class ProjectsController : Controller
} }
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var result = await _createProjectCommand.CreateAsync(project, userId, _currentContext.ClientType); var result = await _createProjectCommand.CreateAsync(project, userId, _currentContext.IdentityClientType);
// Creating a project means you have read & write permission. // Creating a project means you have read & write permission.
return new ProjectResponseModel(result, true, true); return new ProjectResponseModel(result, true, true);
@ -124,7 +124,7 @@ public class ProjectsController : Controller
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var orgAdmin = await _currentContext.OrganizationAdmin(project.OrganizationId); var orgAdmin = await _currentContext.OrganizationAdmin(project.OrganizationId);
var accessClient = AccessClientHelper.ToAccessClient(_currentContext.ClientType, orgAdmin); var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin);
var access = await _projectRepository.AccessToProjectAsync(id, userId, accessClient); var access = await _projectRepository.AccessToProjectAsync(id, userId, accessClient);

View File

@ -85,7 +85,7 @@ public class SecretsController : Controller
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var orgAdmin = await _currentContext.OrganizationAdmin(organizationId); var orgAdmin = await _currentContext.OrganizationAdmin(organizationId);
var accessClient = AccessClientHelper.ToAccessClient(_currentContext.ClientType, orgAdmin); var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin);
var secrets = await _secretRepository.GetManyDetailsByOrganizationIdAsync(organizationId, userId, accessClient); var secrets = await _secretRepository.GetManyDetailsByOrganizationIdAsync(organizationId, userId, accessClient);
@ -136,7 +136,7 @@ public class SecretsController : Controller
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var orgAdmin = await _currentContext.OrganizationAdmin(secret.OrganizationId); var orgAdmin = await _currentContext.OrganizationAdmin(secret.OrganizationId);
var accessClient = AccessClientHelper.ToAccessClient(_currentContext.ClientType, orgAdmin); var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin);
var access = await _secretRepository.AccessToSecretAsync(id, userId, accessClient); var access = await _secretRepository.AccessToSecretAsync(id, userId, accessClient);
@ -145,7 +145,7 @@ public class SecretsController : Controller
throw new NotFoundException(); throw new NotFoundException();
} }
if (_currentContext.ClientType == ClientType.ServiceAccount) if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount)
{ {
await _eventService.LogServiceAccountSecretEventAsync(userId, secret, EventType.Secret_Retrieved); await _eventService.LogServiceAccountSecretEventAsync(userId, secret, EventType.Secret_Retrieved);
@ -167,7 +167,7 @@ public class SecretsController : Controller
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var orgAdmin = await _currentContext.OrganizationAdmin(project.OrganizationId); var orgAdmin = await _currentContext.OrganizationAdmin(project.OrganizationId);
var accessClient = AccessClientHelper.ToAccessClient(_currentContext.ClientType, orgAdmin); var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin);
var secrets = await _secretRepository.GetManyDetailsByProjectIdAsync(projectId, userId, accessClient); var secrets = await _secretRepository.GetManyDetailsByProjectIdAsync(projectId, userId, accessClient);
@ -311,7 +311,7 @@ public class SecretsController : Controller
private async Task LogSecretsRetrievalAsync(Guid organizationId, IEnumerable<Secret> secrets) private async Task LogSecretsRetrievalAsync(Guid organizationId, IEnumerable<Secret> secrets)
{ {
if (_currentContext.ClientType == ClientType.ServiceAccount) if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount)
{ {
var userId = _userService.GetProperUserId(User)!.Value; var userId = _userService.GetProperUserId(User)!.Value;
var org = await _organizationRepository.GetByIdAsync(organizationId); var org = await _organizationRepository.GetByIdAsync(organizationId);

View File

@ -81,7 +81,7 @@ public class ServiceAccountsController : Controller
var userId = _userService.GetProperUserId(User).Value; var userId = _userService.GetProperUserId(User).Value;
var orgAdmin = await _currentContext.OrganizationAdmin(organizationId); var orgAdmin = await _currentContext.OrganizationAdmin(organizationId);
var accessClient = AccessClientHelper.ToAccessClient(_currentContext.ClientType, orgAdmin); var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin);
var results = var results =
await _serviceAccountSecretsDetailsQuery.GetManyByOrganizationIdAsync(organizationId, userId, accessClient, await _serviceAccountSecretsDetailsQuery.GetManyByOrganizationIdAsync(organizationId, userId, accessClient,

View File

@ -39,7 +39,7 @@ public class CurrentContext : ICurrentContext
public virtual int? BotScore { get; set; } public virtual int? BotScore { get; set; }
public virtual string ClientId { get; set; } public virtual string ClientId { get; set; }
public virtual Version ClientVersion { get; set; } public virtual Version ClientVersion { get; set; }
public virtual ClientType ClientType { get; set; } public virtual IdentityClientType IdentityClientType { get; set; }
public virtual Guid? ServiceAccountOrganizationId { get; set; } public virtual Guid? ServiceAccountOrganizationId { get; set; }
public CurrentContext( public CurrentContext(
@ -151,11 +151,11 @@ public class CurrentContext : ICurrentContext
var clientType = GetClaimValue(claimsDict, Claims.Type); var clientType = GetClaimValue(claimsDict, Claims.Type);
if (clientType != null) if (clientType != null)
{ {
Enum.TryParse(clientType, out ClientType c); Enum.TryParse(clientType, out IdentityClientType c);
ClientType = c; IdentityClientType = c;
} }
if (ClientType == ClientType.ServiceAccount) if (IdentityClientType == IdentityClientType.ServiceAccount)
{ {
ServiceAccountOrganizationId = new Guid(GetClaimValue(claimsDict, Claims.Organization)); ServiceAccountOrganizationId = new Guid(GetClaimValue(claimsDict, Claims.Organization));
} }

View File

@ -23,7 +23,7 @@ public interface ICurrentContext
List<CurrentContextOrganization> Organizations { get; set; } List<CurrentContextOrganization> Organizations { get; set; }
Guid? InstallationId { get; set; } Guid? InstallationId { get; set; }
Guid? OrganizationId { get; set; } Guid? OrganizationId { get; set; }
ClientType ClientType { get; set; } IdentityClientType IdentityClientType { get; set; }
bool IsBot { get; set; } bool IsBot { get; set; }
bool MaybeBot { get; set; } bool MaybeBot { get; set; }
int? BotScore { get; set; } int? BotScore { get; set; }

View File

@ -12,19 +12,19 @@ public enum AccessClientType
public static class AccessClientHelper public static class AccessClientHelper
{ {
public static AccessClientType ToAccessClient(ClientType clientType, bool bypassAccessCheck = false) public static AccessClientType ToAccessClient(IdentityClientType identityClientType, bool bypassAccessCheck = false)
{ {
if (bypassAccessCheck) if (bypassAccessCheck)
{ {
return AccessClientType.NoAccessCheck; return AccessClientType.NoAccessCheck;
} }
return clientType switch return identityClientType switch
{ {
ClientType.User => AccessClientType.User, IdentityClientType.User => AccessClientType.User,
ClientType.Organization => AccessClientType.Organization, IdentityClientType.Organization => AccessClientType.Organization,
ClientType.ServiceAccount => AccessClientType.ServiceAccount, IdentityClientType.ServiceAccount => AccessClientType.ServiceAccount,
_ => throw new ArgumentOutOfRangeException(nameof(clientType), clientType, null), _ => throw new ArgumentOutOfRangeException(nameof(identityClientType), identityClientType, null),
}; };
} }
} }

View File

@ -1,7 +1,7 @@
#nullable enable #nullable enable
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
namespace Bit.Core.NotificationCenter.Enums; namespace Bit.Core.Enums;
public enum ClientType : byte public enum ClientType : byte
{ {

View File

@ -1,6 +1,6 @@
namespace Bit.Core.Identity; namespace Bit.Core.Identity;
public enum ClientType : byte public enum IdentityClientType : byte
{ {
User = 0, User = 0,
Organization = 1, Organization = 1,

View File

@ -1,6 +1,7 @@
#nullable enable #nullable enable
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using Bit.Core.Entities; using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Enums; using Bit.Core.NotificationCenter.Enums;
using Bit.Core.Utilities; using Bit.Core.Utilities;

View File

@ -1,6 +1,6 @@
#nullable enable #nullable enable
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Enums;
using Bit.Core.NotificationCenter.Models.Filter; using Bit.Core.NotificationCenter.Models.Filter;
using Bit.Core.Repositories; using Bit.Core.Repositories;

View File

@ -5,5 +5,5 @@ namespace Bit.Core.SecretsManager.Commands.Projects.Interfaces;
public interface ICreateProjectCommand public interface ICreateProjectCommand
{ {
Task<Project> CreateAsync(Project project, Guid userId, ClientType clientType); Task<Project> CreateAsync(Project project, Guid userId, IdentityClientType identityClientType);
} }

View File

@ -1,4 +1,5 @@
using Bit.Core.Context; using Bit.Core.Context;
using Bit.Core.Identity;
using Bit.Core.Settings; using Bit.Core.Settings;
using Bit.Core.Utilities; using Bit.Core.Utilities;
using LaunchDarkly.Logging; using LaunchDarkly.Logging;
@ -153,9 +154,9 @@ public class LaunchDarklyFeatureService : IFeatureService
var builder = LaunchDarkly.Sdk.Context.MultiBuilder(); var builder = LaunchDarkly.Sdk.Context.MultiBuilder();
switch (_currentContext.ClientType) switch (_currentContext.IdentityClientType)
{ {
case Identity.ClientType.User: case IdentityClientType.User:
{ {
ContextBuilder ldUser; ContextBuilder ldUser;
if (_currentContext.UserId.HasValue) if (_currentContext.UserId.HasValue)
@ -182,7 +183,7 @@ public class LaunchDarklyFeatureService : IFeatureService
} }
break; break;
case Identity.ClientType.Organization: case IdentityClientType.Organization:
{ {
if (_currentContext.OrganizationId.HasValue) if (_currentContext.OrganizationId.HasValue)
{ {
@ -196,7 +197,7 @@ public class LaunchDarklyFeatureService : IFeatureService
} }
break; break;
case Identity.ClientType.ServiceAccount: case IdentityClientType.ServiceAccount:
{ {
if (_currentContext.UserId.HasValue) if (_currentContext.UserId.HasValue)
{ {

View File

@ -4,21 +4,56 @@ namespace Bit.Core.Utilities;
public static class DeviceTypes public static class DeviceTypes
{ {
public static IReadOnlyCollection<DeviceType> MobileTypes { get; } = new[] public static IReadOnlyCollection<DeviceType> MobileTypes { get; } =
{ [
DeviceType.Android, DeviceType.Android,
DeviceType.iOS, DeviceType.iOS,
DeviceType.AndroidAmazon, DeviceType.AndroidAmazon
}; ];
public static IReadOnlyCollection<DeviceType> DesktopTypes { get; } = new[] public static IReadOnlyCollection<DeviceType> DesktopTypes { get; } =
{ [
DeviceType.LinuxDesktop, DeviceType.LinuxDesktop,
DeviceType.MacOsDesktop, DeviceType.MacOsDesktop,
DeviceType.WindowsDesktop, DeviceType.WindowsDesktop,
DeviceType.UWP, DeviceType.UWP,
DeviceType.WindowsCLI, DeviceType.WindowsCLI,
DeviceType.MacOsCLI, DeviceType.MacOsCLI,
DeviceType.LinuxCLI, DeviceType.LinuxCLI
}; ];
public static IReadOnlyCollection<DeviceType> BrowserExtensionTypes { get; } =
[
DeviceType.ChromeExtension,
DeviceType.FirefoxExtension,
DeviceType.OperaExtension,
DeviceType.EdgeExtension,
DeviceType.VivaldiExtension,
DeviceType.SafariExtension
];
public static IReadOnlyCollection<DeviceType> BrowserTypes { get; } =
[
DeviceType.ChromeBrowser,
DeviceType.FirefoxBrowser,
DeviceType.OperaBrowser,
DeviceType.EdgeBrowser,
DeviceType.IEBrowser,
DeviceType.SafariBrowser,
DeviceType.VivaldiBrowser,
DeviceType.UnknownBrowser
];
private static ClientType ToClientType(DeviceType? deviceType)
{
return deviceType switch
{
not null when MobileTypes.Contains(deviceType.Value) => ClientType.Mobile,
not null when DesktopTypes.Contains(deviceType.Value) => ClientType.Desktop,
not null when BrowserExtensionTypes.Contains(deviceType.Value) => ClientType.Browser,
not null when BrowserTypes.Contains(deviceType.Value) => ClientType.Web,
_ => ClientType.All
};
}
} }

View File

@ -128,7 +128,7 @@ public class ClientStore : IClientStore
Claims = new List<ClientClaim> Claims = new List<ClientClaim>
{ {
new(JwtClaimTypes.Subject, apiKey.ServiceAccountId.ToString()), new(JwtClaimTypes.Subject, apiKey.ServiceAccountId.ToString()),
new(Claims.Type, ClientType.ServiceAccount.ToString()), new(Claims.Type, IdentityClientType.ServiceAccount.ToString()),
}, },
}; };
@ -160,7 +160,7 @@ public class ClientStore : IClientStore
{ {
new(JwtClaimTypes.Subject, user.Id.ToString()), new(JwtClaimTypes.Subject, user.Id.ToString()),
new(JwtClaimTypes.AuthenticationMethod, "Application", "external"), new(JwtClaimTypes.AuthenticationMethod, "Application", "external"),
new(Claims.Type, ClientType.User.ToString()), new(Claims.Type, IdentityClientType.User.ToString()),
}; };
var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id); var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id);
var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id); var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id);
@ -218,7 +218,7 @@ public class ClientStore : IClientStore
Claims = new List<ClientClaim> Claims = new List<ClientClaim>
{ {
new(JwtClaimTypes.Subject, org.Id.ToString()), new(JwtClaimTypes.Subject, org.Id.ToString()),
new(Claims.Type, ClientType.Organization.ToString()), new(Claims.Type, IdentityClientType.Organization.ToString()),
}, },
}; };
} }

View File

@ -1,7 +1,7 @@
#nullable enable #nullable enable
using System.Data; using System.Data;
using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Entities; using Bit.Core.NotificationCenter.Entities;
using Bit.Core.NotificationCenter.Enums;
using Bit.Core.NotificationCenter.Models.Filter; using Bit.Core.NotificationCenter.Models.Filter;
using Bit.Core.NotificationCenter.Repositories; using Bit.Core.NotificationCenter.Repositories;
using Bit.Core.Settings; using Bit.Core.Settings;

View File

@ -1,6 +1,6 @@
#nullable enable #nullable enable
using AutoMapper; using AutoMapper;
using Bit.Core.NotificationCenter.Enums; using Bit.Core.Enums;
using Bit.Core.NotificationCenter.Models.Filter; using Bit.Core.NotificationCenter.Models.Filter;
using Bit.Core.NotificationCenter.Repositories; using Bit.Core.NotificationCenter.Repositories;
using Bit.Infrastructure.EntityFramework.NotificationCenter.Models; using Bit.Infrastructure.EntityFramework.NotificationCenter.Models;

View File

@ -115,12 +115,12 @@ public class ProjectsControllerTests
var resultProject = data.ToProject(orgId); var resultProject = data.ToProject(orgId);
sutProvider.GetDependency<ICreateProjectCommand>().CreateAsync(default, default, sutProvider.GetDependency<ICurrentContext>().ClientType) sutProvider.GetDependency<ICreateProjectCommand>().CreateAsync(default, default, sutProvider.GetDependency<ICurrentContext>().IdentityClientType)
.ReturnsForAnyArgs(resultProject); .ReturnsForAnyArgs(resultProject);
await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.CreateAsync(orgId, data)); await Assert.ThrowsAsync<NotFoundException>(() => sutProvider.Sut.CreateAsync(orgId, data));
await sutProvider.GetDependency<ICreateProjectCommand>().DidNotReceiveWithAnyArgs() await sutProvider.GetDependency<ICreateProjectCommand>().DidNotReceiveWithAnyArgs()
.CreateAsync(Arg.Any<Project>(), Arg.Any<Guid>(), sutProvider.GetDependency<ICurrentContext>().ClientType); .CreateAsync(Arg.Any<Project>(), Arg.Any<Guid>(), sutProvider.GetDependency<ICurrentContext>().IdentityClientType);
} }
[Theory] [Theory]
@ -138,7 +138,7 @@ public class ProjectsControllerTests
await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.CreateAsync(orgId, data)); await Assert.ThrowsAsync<BadRequestException>(() => sutProvider.Sut.CreateAsync(orgId, data));
await sutProvider.GetDependency<ICreateProjectCommand>().DidNotReceiveWithAnyArgs() await sutProvider.GetDependency<ICreateProjectCommand>().DidNotReceiveWithAnyArgs()
.CreateAsync(Arg.Any<Project>(), Arg.Any<Guid>(), sutProvider.GetDependency<ICurrentContext>().ClientType); .CreateAsync(Arg.Any<Project>(), Arg.Any<Guid>(), sutProvider.GetDependency<ICurrentContext>().IdentityClientType);
} }
[Theory] [Theory]
@ -153,13 +153,13 @@ public class ProjectsControllerTests
var resultProject = data.ToProject(orgId); var resultProject = data.ToProject(orgId);
sutProvider.GetDependency<ICreateProjectCommand>().CreateAsync(default, default, sutProvider.GetDependency<ICurrentContext>().ClientType) sutProvider.GetDependency<ICreateProjectCommand>().CreateAsync(default, default, sutProvider.GetDependency<ICurrentContext>().IdentityClientType)
.ReturnsForAnyArgs(resultProject); .ReturnsForAnyArgs(resultProject);
await sutProvider.Sut.CreateAsync(orgId, data); await sutProvider.Sut.CreateAsync(orgId, data);
await sutProvider.GetDependency<ICreateProjectCommand>().Received(1) await sutProvider.GetDependency<ICreateProjectCommand>().Received(1)
.CreateAsync(Arg.Any<Project>(), Arg.Any<Guid>(), sutProvider.GetDependency<ICurrentContext>().ClientType); .CreateAsync(Arg.Any<Project>(), Arg.Any<Guid>(), sutProvider.GetDependency<ICurrentContext>().IdentityClientType);
} }
[Theory] [Theory]